I often need to perform an inverse selection of columns in a DataFrame, or exclude some columns from a query. Here is a simple technique for doing that; I use it frequently when arranging features into vectors for machine learning tasks.
import org.apache.spark.sql.Column
// Create an example dataframe
val dataDF = spark.createDataFrame(Seq(
(1, 1, 2, 3, 8, 4, 5),
(2, 4, 3, 8, 7, 9, 8),
(3, 6, 1, 9, 2, 3, 6),
(4, 10, 8, 6, 9, 4, 5),
(5, 9, 2, 7, 10, 7, 3),
(6, 1, 1, 4, 2, 8, 4)
)).toDF("colToExclude", "col1", "col2", "col3", "col4", "col5", "col6")
// Get an array of all columns in the dataframe, then
// filter out the columns you want to exclude from the
// final dataframe.
val colsToSelect = dataDF.columns.filter(_ != "colToExclude")
// Take a look at the array as comma-separated values.
colsToSelect.mkString(",")
// This method performs a simple selection using column-name strings
dataDF.select(colsToSelect.head, colsToSelect.tail: _*).show()
// This method creates a new dataframe using your column list
// Filter dataDF using the colsToSelect array, and map
// the results into columns.
dataDF.select(
  dataDF.columns
    .filter(colName => colsToSelect.contains(colName))
    .map(colName => new Column(colName)): _*
).show()
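As a side note, the same mapping can be written with the col helper from org.apache.spark.sql.functions instead of constructing Column objects with new; and since colsToSelect has already been filtered, the extra filter pass over dataDF.columns can be skipped. A minimal equivalent sketch:
import org.apache.spark.sql.functions.col
// Map the pre-filtered column names directly to Columns
dataDF.select(colsToSelect.map(col): _*).show()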
In the simple selection method, note that we had to use colsToSelect.head and colsToSelect.tail: _*. The reason is that the string-based overload of select() has the signature select(col: String, cols: String*): it takes a first column name followed by a varargs list of additional names. An Array[String] by itself matches neither that overload nor the Column-based select(cols: Column*), so if you just pass in the array without splitting it into .head and .tail, you'll get an overloaded method error.
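To make that concrete, here is a minimal sketch of what does and doesn't compile with the string-based overload (the failing calls are left commented out):
// Fails to compile: select(col: String, cols: String*) expects a first String
// followed by String varargs, and select(cols: Column*) expects Columns, so
// an Array[String] on its own matches neither overload.
// dataDF.select(colsToSelect)       // overloaded method error
// dataDF.select(colsToSelect: _*)   // no select(String*) overload exists

// Compiles: first name as the String argument, the rest expanded as varargs.
dataDF.select(colsToSelect.head, colsToSelect.tail: _*).show()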