Retrieve column value given a column of column names (spark / scala)

Time:06-23

I have a DataFrame like the following:

+-----------+-----------+---------------+------+---------------------+
|best_col   |A          |B              |  C   |<many more columns>  |
+-----------+-----------+---------------+------+---------------------+
|     A     |    14     |        26     |  32  |       ...           |
|     C     |    13     |        17     |  96  |       ...           |
|     B     |    23     |        19     |  42  |       ...           |
+-----------+-----------+---------------+------+---------------------+

I want to end up with a DataFrame like this:

+-----------+-----------+---------------+------+---------------------+----------+
|best_col   |A          |B              |  C   |<many more columns>  | result   |
+-----------+-----------+---------------+------+---------------------+----------+
|     A     |    14     |        26     |  32  |       ...           |   14     |
|     C     |    13     |        17     |  96  |       ...           |   96     |
|     B     |    23     |        19     |  42  |       ...           |   19     |
+-----------+-----------+---------------+------+---------------------+----------+

Essentially, I want to add a column result that will choose the value from the column specified in the best_col column. best_col only contains column names that are present in the DataFrame. Since I have dozens of columns, I want to avoid using a bunch of when statements to check when col(best_col) === A etc. I tried doing col(col("best_col").toString()), but this didn't work. Is there an easy way to do this?
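For comparison, the chain of when expressions mentioned above does not have to be typed out by hand; it can be folded from df.columns into a single CASE WHEN expression. This is a minimal sketch, assuming the name column is called best_col and that all other columns share one type (CASE WHEN branches must have compatible types). The object and helper names (BestColWhenChain, withWhenChain) are hypothetical. Because it only uses when and Column.when, it should also work on Spark versions older than 3.0.

```scala
import org.apache.spark.sql.{Column, DataFrame, SparkSession}
import org.apache.spark.sql.functions.{col, when}

object BestColWhenChain {
  // Hypothetical helper: builds
  //   CASE WHEN nameCol = 'A' THEN A WHEN nameCol = 'B' THEN B ... END
  // from every column except nameCol, so no branch is written by hand.
  // Rows whose nameCol matches no column get null (there is no otherwise()).
  def withWhenChain(df: DataFrame, nameCol: String, resultCol: String): DataFrame = {
    val candidates = df.columns.toSeq.filterNot(_ == nameCol)
    require(candidates.nonEmpty, "need at least one candidate column")

    // Start the CASE WHEN with the first candidate column...
    val first: Column = when(col(nameCol) === candidates.head, col(candidates.head))
    // ...and append one WHEN branch per remaining candidate column.
    val caseExpr: Column = candidates.tail.foldLeft(first) { (acc, c) =>
      acc.when(col(nameCol) === c, col(c))
    }
    df.withColumn(resultCol, caseExpr)
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("best-col-when").getOrCreate()
    import spark.implicits._

    val df = Seq(("A", 14, 26, 32), ("C", 13, 17, 96), ("B", 23, 19, 42))
      .toDF("best_col", "A", "B", "C")
    withWhenChain(df, "best_col", "result").show()
    spark.stop()
  }
}
```

Spark still evaluates one branch per candidate column, but the expression is generated from the schema, so adding columns needs no code changes.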

Answer:

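Using map_filter, introduced in Spark 3.0: build a map from each column's value to a flag that is true only for the column named in best_col, keep the entry whose flag is true, and read back its key: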

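import org.apache.spark.sql.functions._
import spark.implicits._  // assumes an active SparkSession named `spark`, e.g. in spark-shell

val df = Seq(
    ("A", 14, 26, 32),
    ("C", 13, 17, 96),
    ("B", 23, 19, 42)
).toDF("best_col", "A", "B", "C")

// value -> (best_col === column name); df.columns.tail assumes best_col is the first column.
// Caveat: the cell values become map keys, so a null, or two columns with the same value
// in one row, can make map() fail at runtime under Spark 3.x defaults.
df.withColumn("result", map(df.columns.tail.flatMap(c => Seq(col(c), col("best_col") === lit(c))): _*))
    // keep only the entry whose flag is true
    .withColumn("result", map_filter(col("result"), (k, v) => v))
    // the surviving key is the value from the column named in best_col
    .withColumn("result", map_keys(col("result"))(0))
    .show()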

+--------+---+---+---+------+
|best_col|  A|  B|  C|result|
+--------+---+---+---+------+
|       A| 14| 26| 32|    14|
|       C| 13| 17| 96|    96|
|       B| 23| 19| 42|    19|
+--------+---+---+---+------+
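The map_filter approach uses cell values as map keys, so (assuming Spark 3.x default settings) a row in which two candidate columns hold the same value, or any candidate holds a null, can make map() fail at runtime. A variant that sidesteps this keys the map by column name instead and looks up each row's best_col directly. This is a sketch under the question's own assumptions (best_col only contains names of existing columns, and the candidate columns share one type, since map values must have a common type); the object and helper names (BestColMapLookup, withMapLookup) are hypothetical.

```scala
import org.apache.spark.sql.{Column, DataFrame, SparkSession}
import org.apache.spark.sql.functions.{col, lit, map}

object BestColMapLookup {
  // Hypothetical helper: builds map("A" -> A, "B" -> B, ...) for each row,
  // keyed by column name, then extracts the entry whose key equals nameCol.
  // Column names are unique and non-null, so the map keys never collide.
  def withMapLookup(df: DataFrame, nameCol: String, resultCol: String): DataFrame = {
    val candidates = df.columns.toSeq.filterNot(_ == nameCol)
    val nameToValue: Column = map(candidates.flatMap(c => Seq(lit(c), col(c))): _*)
    // Column.apply with a Column argument does a per-row map lookup.
    df.withColumn(resultCol, nameToValue(col(nameCol)))
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("best-col-map").getOrCreate()
    import spark.implicits._

    // The last row repeats a value across columns, which the value-keyed map cannot handle.
    val df = Seq(("A", 14, 26, 32), ("C", 13, 17, 96), ("B", 23, 19, 42), ("B", 5, 5, 5))
      .toDF("best_col", "A", "B", "C")
    withMapLookup(df, "best_col", "result").show()
    spark.stop()
  }
}
```

This variant needs neither map_filter nor map_keys, so it is not tied to Spark 3.0.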