Transpose DataFrame Spark Scala

Time:08-11

Task:

I need to transpose a DataFrame. I have a working solution, but I would like better performance.

INPUT:

val columnsNames = List("col_name1", "col_name2")

DataFrame:
+-----------+-------------------+-------------------+-------------------+-------------------+--------------+
|period_date|col_name1#max_value|col_name1#min_value|col_name2#max_value|col_name2#min_value|period_last_dt|
+-----------+-------------------+-------------------+-------------------+-------------------+--------------+
| 2022-02-28|               12.0|               12.0|               22.0|               22.0|    2022-02-28|
| 2022-01-31|               11.0|               11.0|               21.0|               21.0|    2022-01-31|
| 2022-03-31|               13.0|               13.0|               23.0|               23.0|    2022-03-31|
+-----------+-------------------+-------------------+-------------------+-------------------+--------------+

OUTPUT:

DataFrame:
+--------------+-----------+---------+---------+
|period_last_dt|column_name|max_value|min_value|
+--------------+-----------+---------+---------+
|    2022-02-28|  col_name1|     12.0|     12.0|
|    2022-02-28|  col_name2|     22.0|     22.0|
|    2022-01-31|  col_name1|     11.0|     11.0|
|    2022-01-31|  col_name2|     21.0|     21.0|
|    2022-03-31|  col_name1|     13.0|     13.0|
|    2022-03-31|  col_name2|     23.0|     23.0|
+--------------+-----------+---------+---------+

My solution:

https://scastie.scala-lang.org/DQleVDXaSlCKWCpnNWNayA

import org.apache.spark.sql.{Row, SparkSession}

// Structure for the resulting dataset.
case class Structure(period_last_dt: String, column_name: String, max_value: Double, min_value: Double)


class RealStatistic(spark: SparkSession) {

  import spark.implicits._
  
  val columnsNames = List("col_name1", "col_name2")

  val inputDf =  Seq(
                     ("2022-02-28", 12.0, 12.0, 22.0, 22.0, "2022-02-28"),
                     ("2022-01-31", 11.0, 11.0, 21.0, 21.0, "2022-01-31"),
                     ("2022-03-31", 13.0, 13.0, 23.0, 23.0, "2022-03-31")
                    ).toDF("period_date", "col_name1#max_value", "col_name1#min_value", "col_name2#max_value", "col_name2#min_value", "period_last_dt")
  inputDf.show()


  // Problem: collect() pulls every row onto the driver before reshaping.
  val resultDf =
    inputDf.collect.map(row => realStatisticsOn(columnsNames, row))
                   .reduceOption(_ union _)
                   .getOrElse(List.empty[Structure])
                   .toDF()


  resultDf.show()


  def realStatisticsOn(columns: List[String], row: Row): List[Structure] =
     columns.map(name => realStatisticOn(name, row))


  
  def realStatisticOn(column: String, row: Row): Structure =
      Structure(
        period_last_dt = row.getAs[String]("period_last_dt"),
        column_name = column,
        max_value = row.getAs[Double](s"${column}#max_value"),
        min_value = row.getAs[Double](s"${column}#min_value")
      )
  
}

Problem:

My solution calls the collect method, and I would like to avoid that. I need help or a hint.

CodePudding user response:

There are functions like "pivot" and "unpivot" for transposing a DataFrame, but for your specific task you can do the following:

  1. For each column name that should appear as a row value, create a DataFrame in which that name is a literal value in the "column_name" column (df1 for "col_name1", df2 for "col_name2", and so on).

  2. Union all the DataFrames created this way, then order the result as needed.

Sharing the code that worked for me (extension of your code):

import org.apache.spark.sql.functions.{col, lit}

case class Structure(period_last_dt: String, column_name: String, max_value: Double, min_value: Double)

// Empty DataFrame with the required structure, used as the starting value of the fold.
// (.toDF needs spark.implicits._ in scope, as in the question's code.)
val result_df = Seq.empty[Structure].toDF
val columnsNames = List("col_name1", "col_name2")

// For each name, project its max/min columns and union that slice onto the accumulator.
val result = columnsNames.foldLeft(result_df) { (df, colName) =>
  df.union(
    inputDf.select(
      col("period_last_dt"),
      lit(colName).as("column_name"),
      col(s"$colName#max_value").as("max_value"),
      col(s"$colName#min_value").as("min_value")
    )
  )
}
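
The fold above adds one union branch per column name. A different technique, not part of the answer above, is to pack each name's values into a struct, put the structs into an array, and explode that array, so the whole reshape is a single select over inputDf. The following is only a minimal sketch: UnpivotSketch, unpivot, and the local SparkSession settings are hypothetical names chosen for illustration, and it assumes that names is non-empty and that all max/min columns share the same type.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.{array, col, explode, lit, struct}

object UnpivotSketch {

  // Hypothetical helper: turns "<name>#max_value" / "<name>#min_value" columns
  // into one output row per name. Assumes `names` is non-empty.
  def unpivot(df: DataFrame, names: Seq[String]): DataFrame = {
    // One struct per logical column: its name plus its two statistics.
    val entries = names.map { name =>
      struct(
        lit(name).as("column_name"),
        col(s"$name#max_value").as("max_value"),
        col(s"$name#min_value").as("min_value")
      )
    }

    // explode emits one row per array element, i.e. one row per name.
    df.select(col("period_last_dt"), explode(array(entries: _*)).as("entry"))
      .select("period_last_dt", "entry.column_name", "entry.max_value", "entry.min_value")
  }

  def main(args: Array[String]): Unit = {
    // Local session for experimentation only; settings are assumptions.
    val spark = SparkSession.builder().master("local[*]").appName("unpivot-sketch").getOrCreate()
    import spark.implicits._

    val inputDf = Seq(
      ("2022-02-28", 12.0, 12.0, 22.0, 22.0, "2022-02-28"),
      ("2022-01-31", 11.0, 11.0, 21.0, 21.0, "2022-01-31"),
      ("2022-03-31", 13.0, 13.0, 23.0, 23.0, "2022-03-31")
    ).toDF("period_date", "col_name1#max_value", "col_name1#min_value",
           "col_name2#max_value", "col_name2#min_value", "period_last_dt")

    unpivot(inputDf, Seq("col_name1", "col_name2")).show()
    spark.stop()
  }
}
```

Like the fold, this never calls collect, so the work stays on the executors. Whether it is actually faster on your data is worth checking by comparing the plans with explain().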