How to aggregate the columns dynamically in spark scala?

Time:12-22

I have just started working with Spark in Scala. I have a requirement wherein I need to find the sum of a few columns within a CASE statement. I've written the corresponding Spark SQL code but am unable to implement the same logic dynamically in Scala. Below is what I'm trying to achieve:

SQL code:

Select  col_A,
        round(case when sum(amt_M)   <> 0.0 then sum(amt_M) 
                   when sum(amt_N)   <> 0.0 then sum(amt_N)
                   when sum(amt_P)   <> 0.0 then sum(amt_P) 
              end,1) as pct 
from table_T1
group by col_A
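The CASE expression tests each unrounded sum in order, returns the first one that is non-zero, and only then rounds it. With no ELSE branch, the result is NULL when every sum is zero. A plain-Scala model of those semantics may help when checking a Spark implementation against the SQL. It is a sketch, not Spark code: `CaseModel`, `firstNonZero`, `round1` and `pct` are hypothetical names, and HALF_UP rounding is assumed to match Spark's documented `round` behavior.

```scala
import scala.math.BigDecimal.RoundingMode

object CaseModel {
  // Mirrors CASE WHEN s1 <> 0 THEN s1 WHEN s2 <> 0 THEN s2 ... END:
  // the first non-zero sum wins; no match means SQL NULL, modelled as None.
  def firstNonZero(sumsInOrder: Seq[Double]): Option[Double] =
    sumsInOrder.find(_ != 0.0)

  // round(x, 1) with HALF_UP (assumed to match Spark's round())
  def round1(x: Double): Double =
    BigDecimal(x).setScale(1, RoundingMode.HALF_UP).toDouble

  // Select on the unrounded sums first, then round only the selected one
  def pct(sumsInOrder: Seq[Double]): Option[Double] =
    firstNonZero(sumsInOrder).map(round1)
}
```

For example, `CaseModel.pct(Seq(0.04, 2.0))` gives `Some(0.0)`: the first sum is non-zero, so it is selected even though it rounds to 0.0.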

The use case is to read the column names from a variable and build the CASE logic above dynamically. There are currently 3 columns, but that number could increase later on.

Below is the code I tried in Spark Scala:

import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import scala.collection._

val df = spark.table("database.table_T1")

val cols = "amt_M,amt_N,amt_P"

val aggCols = cols.split(",").toSeq

val sums = aggCols.map(colName => when(round(sum(colName).cast(DoubleType),1) =!= 0.0,sum(colName).cast(DoubleType).alias("sum_" + colName)))

val df2 = df.groupBy(col("col_A")).agg(sums.head, sums.tail:_*)

However, this does not give the desired result. Please help me with this.

Input Data

+--------+--------------------+---------------------+----------------------+
|col_A   |amt_M               |amt_N                |amt_P                 |
+--------+--------------------+---------------------+----------------------+
|5C-SVS-1|0.0                 |0.04064912622009295  |1.6256888829356116E-4 |
|5C-SVS-1|0.0                 |0.026542159153759487 |8.574900251977566E-4  |
|5C-SVS-1|0.0                 |5.703894148377958E-5 |1.0745888408402782E-7 |
|5C-SVS-1|0.0                 |0.0                  |4.514561031069833E-4  |
|5C-SVS-1|0.0                 |0.011794053124022862 |0.0020388259536434656 |
|5C-SVS-1|0.0                 |7.55793849084569E-4  |0.0017105736019335327 |
|5C-SVS-1|0.0                 |0.019303776946698548 |2.240625765755109E-5  |
|5C-SVS-1|0.0                 |-8.028117213883126E-6|-2.1979360825171534E-6|
|5C-SVS-1|0.001940948839163001|0.029163686986129422 |0.09505621692309557   |
|5C-SVS-1|0.0                 |2.515835289984397E-7 |1.1486227577926157E-8 |
|5C-SVS-1|0.0                 |0.007844299114837874 |9.974187712854785E-4  |
|5C-SVS-1|0.0                 |5.033123682586349E-4 |1.3644443189731007E-4 |
|5C-SVS-1|0.0                 |0.026331681277001386 |6.022434166108063E-4  |
|5C-SVS-1|0.0                 |8.098023638080503E-6 |1.0                   |
|5C-SVS-1|0.0                 |0.03655893437209876  |0.003113370686486882  |
|5C-SVS-1|0.0                 |0.01409363925733864  |6.239415097038338E-4  |
|5C-SVS-1|0.0                 |0.02171856350557304  |0.0                   |
|5C-SVS-1|0.008435341548288601|0.03347191686227869  |0.35221710556006247   |
|5C-SVS-1|0.0                 |-2.547132732700875E-6|-0.13073525789233997  |
|5C-SVS-1|0.006057441518729214|0.024036273783621134 |0.21447606070652467   |
+--------+--------------------+---------------------+----------------------+

Expected Output

+--------+---+
|   col_A|pct|
+--------+---+
|5C-SVS-1|1.0|
+--------+---+

Thanks
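The attempt in the question maps each column to its own `when(...)` expression and passes all of them to `agg`, so the result has one output column per amount column instead of a single `pct` column. One way to collapse them into a single CASE-like expression is to drop the `otherwise`, so each branch yields null when its sum is 0.0, and wrap the branches in `coalesce`, which returns the first non-null one. A minimal sketch, assuming Spark on the classpath and the `col_A`/`amt_*` layout from the question (`CoalesceSketch`, `firstNonZeroSum` and `pct` are hypothetical names):

```scala
import org.apache.spark.sql.{Column, DataFrame}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.DoubleType

object CoalesceSketch {
  // Intended to match: round(CASE WHEN sum(c1) <> 0 THEN sum(c1) WHEN sum(c2) <> 0 THEN sum(c2) ... END, 1)
  def firstNonZeroSum(colNames: Seq[String]): Column = {
    require(colNames.nonEmpty, "at least one column name is required")
    val branches = colNames.map { c =>
      val s = sum(col(c).cast(DoubleType))
      // No otherwise(...): the branch is null when the unrounded sum is 0.0
      when(s =!= 0.0, s)
    }
    // coalesce keeps the first non-null branch, i.e. the first non-zero sum
    round(coalesce(branches: _*), 1)
  }

  // cols is a comma-separated list such as "amt_M,amt_N,amt_P"
  def pct(df: DataFrame, cols: String): DataFrame = {
    val aggCols = cols.split(",").map(_.trim).filter(_.nonEmpty).toSeq
    df.groupBy(col("col_A")).agg(firstNonZeroSum(aggCols).alias("pct"))
  }
}
```

Usage: `CoalesceSketch.pct(df, "amt_M,amt_N,amt_P")`. Adding a column then only means extending the string; when every sum in a group is 0.0, `pct` is null, as with the SQL version.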

Answer:

You could first groupBy your DataFrame on col_A, calculate the sums, and afterwards use a map operation to choose which sum you want to keep. Something like this:

import org.apache.spark.sql.types._
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions._
// Needed for Dataset.map and .toDF further down
import spark.implicits._

// Schema that fixes the column types when reading the CSV
val schema = new StructType()
    .add("col_A", StringType)
    .add("amt_M", DoubleType)
    .add("amt_N", DoubleType)
    .add("amt_P", DoubleType)

// Reading in the Dataframe using our premade schema. I put the data in a CSV
// file with ; as delimiters.
val df = spark.read
    .option("header", "true")
    .option("sep",";")
    .schema(schema)
    .csv("./someData.csv")

df.show
+--------+--------------------+--------------------+--------------------+
|   col_A|               amt_M|               amt_N|               amt_P|
+--------+--------------------+--------------------+--------------------+
|5C-SVS-1|                 0.0| 0.04064912622009295|1.625688882935611...|
|5C-SVS-1|                 0.0|0.026542159153759487|8.574900251977566E-4|
|5C-SVS-1|                 0.0|5.703894148377958E-5|1.074588840840278...|
|5C-SVS-1|                 0.0|                 0.0|4.514561031069833E-4|
|5C-SVS-1|                 0.0|0.011794053124022862|0.002038825953643...|
|5C-SVS-1|                 0.0| 7.55793849084569E-4|0.001710573601933...|
|5C-SVS-1|                 0.0|0.019303776946698548|2.240625765755109E-5|
|5C-SVS-1|                 0.0|-8.02811721388312...|-2.19793608251715...|
|5C-SVS-1|0.001940948839163001|0.029163686986129422| 0.09505621692309557|
|5C-SVS-1|                 0.0|2.515835289984397E-7|1.148622757792615...|
|5C-SVS-1|                 0.0|0.007844299114837874|9.974187712854785E-4|
|5C-SVS-1|                 0.0|5.033123682586349E-4|1.364444318973100...|
|5C-SVS-1|                 0.0|0.026331681277001386|6.022434166108063E-4|
|5C-SVS-1|                 0.0|8.098023638080503E-6|                 1.0|
|5C-SVS-1|                 0.0| 0.03655893437209876|0.003113370686486882|
|5C-SVS-1|                 0.0| 0.01409363925733864|6.239415097038338E-4|
|5C-SVS-1|                 0.0| 0.02171856350557304|                 0.0|
|5C-SVS-1|0.008435341548288601| 0.03347191686227869| 0.35221710556006247|
|5C-SVS-1|                 0.0|-2.54713273270087...|-0.13073525789233997|
|5C-SVS-1|0.006057441518729214|0.024036273783621134| 0.21447606070652467|
+--------+--------------------+--------------------+--------------------+

// Aggregating our data for each distinct value in col_A, summing all the amt columns
val aggregated_df = df.groupBy(col("col_A"))
    .agg(
        round(sum(col("amt_M")).as("amt_M_sum"), 1),
        round(sum(col("amt_N")).as("amt_N_sum"), 1),
        round(sum(col("amt_P")).as("amt_P_sum"), 1)
)

aggregated_df.show
+--------+---------------------------------+---------------------------------+---------------------------------+
|   col_A|round(sum(amt_M) AS amt_M_sum, 1)|round(sum(amt_N) AS amt_N_sum, 1)|round(sum(amt_P) AS amt_P_sum, 1)|
+--------+---------------------------------+---------------------------------+---------------------------------+
|5C-SVS-1|                              0.0|                              0.3|                              1.5|
+--------+---------------------------------+---------------------------------+---------------------------------+


// Selecting our wanted values. We make use of Scala pattern matching here to
// easily deconstruct our data and make something readable
val output = aggregated_df.map(
    row => row match {
        case Row(col_A: String, sum_amt_M: Double, sum_amt_N: Double, sum_amt_P: Double) => {
            if (sum_amt_M != 0.0)
                (col_A, sum_amt_M)
            else if (sum_amt_N != 0.0)
                (col_A, sum_amt_N)
            else
                (col_A, sum_amt_P)
        }
    }
).toDF("col_A", "pct")

output.show
+--------+---+
|   col_A|pct|
+--------+---+
|5C-SVS-1|0.3|
+--------+---+

Note: what should happen if all of the sums are 0? That's up to you to decide: I used sum_amt_P as the catch-all else case. From here on you can adapt the logic inside the map function to get whatever you want.

Hope this helps!
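Since the question asks for a dynamic number of columns, the hardcoded `agg` and the fixed-arity `Row(...)` pattern in the answer above can be generalized. A minimal sketch, assuming Spark on the classpath and that every sum is a non-null double (`row.getDouble` throws on null); `MapChoiceSketch` and `pct` are hypothetical names. It builds the `sum` columns from a list and, in the `map`, takes the first non-zero sum with `find`, returning `None` (null in the output) when all sums are 0.0 instead of falling back to the last one. Like the answer, it rounds before comparing, so a small non-zero sum that rounds to 0.0 is skipped.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions._

object MapChoiceSketch {
  def pct(spark: SparkSession, df: DataFrame, aggCols: Seq[String]): DataFrame = {
    import spark.implicits._
    require(aggCols.nonEmpty, "need at least one amount column")
    // One rounded sum per column, in priority order
    val sums = aggCols.map(c => round(sum(col(c)), 1).as(s"${c}_sum"))
    val aggregated = df.groupBy(col("col_A")).agg(sums.head, sums.tail: _*)
    aggregated.map { row =>
      // Columns 1..n hold the rounded sums
      val values = (1 until row.length).map(row.getDouble)
      // First non-zero sum, or None (null in the output) if all are 0.0
      (row.getString(0), values.find(_ != 0.0))
    }.toDF("col_A", "pct")
  }
}
```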

Answer:

I solved the requirement by implementing the method below:

import org.apache.spark.sql.Column
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._

// Builds round(CASE WHEN sum(c1) <> 0 THEN sum(c1) WHEN sum(c2) <> 0 THEN sum(c2) ... END, 1)
// Assumes columnList is non-empty.
def getSumCols(columnList: List[String]): Column = {

    // Start the chain with the first column
    var conditionColumn: Column = when(sum(col(columnList(0)).cast(DoubleType)) =!= 0.0, sum(col(columnList(0)).cast(DoubleType)))

    // Append one WHEN branch for each remaining column (2nd element to the end)
    for (c <- 1 until columnList.length) {
        conditionColumn = conditionColumn.when(sum(col(columnList(c)).cast(DoubleType)) =!= 0.0, sum(col(columnList(c)).cast(DoubleType)))
    }
    round(conditionColumn, 1)
}

The method is then called during the aggregation, as below:

val cols = "amt_M,amt_N,amt_P"

val colList = cols.split(",").toList

val conditionColumn: Column = getSumCols(colList)

val df1 = df.groupBy("col_A").agg(conditionColumn.alias("pct"))
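The same chain can be built without a `var` by folding over the remaining column names: `functions.when` opens the chain for the first column and `Column.when` appends a branch for each of the rest. A minimal sketch, assuming Spark on the classpath, a non-empty list and the same `col_A`/`amt_*` columns (`SumColsFold` and `getSumColsFold` are hypothetical names):

```scala
import org.apache.spark.sql.Column
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.DoubleType

object SumColsFold {
  // Same expression as getSumCols, built with foldLeft instead of a mutable var
  def getSumColsFold(columnList: List[String]): Column = columnList match {
    case Nil =>
      throw new IllegalArgumentException("columnList must not be empty")
    case first :: rest =>
      // Sum of one column, cast to double as in getSumCols
      def s(c: String): Column = sum(col(c).cast(DoubleType))
      // functions.when opens the chain; Column.when appends one branch per remaining column
      val chain = rest.foldLeft(when(s(first) =!= 0.0, s(first))) { (acc, c) =>
        acc.when(s(c) =!= 0.0, s(c))
      }
      // No otherwise(...), so the result is null when every sum is 0.0
      round(chain, 1)
  }
}
```

Usage mirrors the original: `df.groupBy("col_A").agg(SumColsFold.getSumColsFold(colList).alias("pct"))`.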