spark, scala

Selecting Dynamic Columns In Spark DataFrames (aka Excluding Columns)

I often need to perform an inverse selection of columns in a DataFrame, that is, exclude some columns from a query. The method below is simple, and I use it frequently when arranging features into vectors for machine learning tasks.

import org.apache.spark.sql.Column

// Create an example dataframe
val dataDF = spark.createDataFrame(Seq(
  (1, 1, 2, 3, 8, 4, 5),
  (2, 4, 3, 8, 7, 9, 8),
  (3, 6, 1, 9, 2, 3, 6),
  (4, 10, 8, 6, 9, 4, 5),
  (5, 9, 2, 7, 10, 7, 3),
  (6, 1, 1, 4, 2, 8, 4)
)).toDF("colToExclude", "col1", "col2", "col3", "col4", "col5", "col6")

// Get an array of all columns in the dataframe, then
// filter out the columns you want to exclude from the 
// final dataframe.
val colsToSelect = dataDF.columns.filter(_ != "colToExclude")

// Take a look at the array as comma-separated values.
colsToSelect.mkString(",")

// This method allows you to perform a simple selection
dataDF.select(colsToSelect.head, colsToSelect.tail: _*).show()

// This method creates a new dataframe using your column list.
// Filter dataDF.columns against the colsToSelect array, then map
// each surviving name into a Column.
dataDF.select(
  dataDF.columns
    .filter(colName => colsToSelect.contains(colName))
    .map(colName => new Column(colName)): _*
).show()
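As an aside, two shorter forms cover the same ground; a minimal sketch, assuming Spark 2.0+ for the multi-column drop():

// The built-in drop() method performs the exclusion directly.
dataDF.drop("colToExclude").show()

// The Column-based select() overload accepts a mapped array
// without any head/tail splitting.
import org.apache.spark.sql.functions.col
dataDF.select(colsToSelect.map(col): _*).show()

The drop() form is the most concise when you only need to remove a handful of named columns; the map(col) form keeps the explicit column list, which is handy when the list is computed elsewhere.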

In the simple selection method, note that we had to use colsToSelect.head and colsToSelect.tail: _*. The reason is that the string-based select() overload has the signature select(col: String, cols: String*): a first column name followed by a varargs parameter. An Array[String] on its own matches neither that overload nor the Column-based select(cols: Column*), so if you just put in the array name without using .head and .tail, you'll get an overloaded method error.
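The head/tail dance can be seen in plain Scala, away from Spark entirely; the method below is illustrative (not a Spark API) but has the same (String, String*) shape as the string-based select() overload:

// Illustrative method: a first argument plus a varargs parameter,
// mirroring select(col: String, cols: String*).
def pick(first: String, rest: String*): Seq[String] = first +: rest

val cols = Array("col1", "col2", "col3")

// pick(cols) does not compile: an Array[String] is not a String.
// Splitting into head and tail (expanded with : _*) satisfies
// the signature.
val chosen = pick(cols.head, cols.tail: _*)
// chosen: Seq("col1", "col2", "col3")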

About James Conner

Scuba dive master, wildlife photographer, anthropologist, programmer, electronics tinkerer and big data expert.