I have recently started working in spark-scala. I have a requirement where I need to find the sum of a few columns within a case statement. I've written the corresponding spark-sql code, but I am unable to implement the same logic dynamically in spark-scala. Below is what I'm trying to achieve -
SQL Code-
Select col_A,
round(case when sum(amt_M) <> 0.0 then sum(amt_M)
when sum(amt_N) <> 0.0 then sum(amt_N)
when sum(amt_P) <> 0.0 then sum(amt_P)
end,1) as pct
from table_T1
group by col_A
The use case is to read the column names from a variable and implement the case-statement logic above dynamically. Currently there are 3 columns, but that number could increase later on.
Below is the code I tried in spark-scala -
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import scala.collection._
val df = spark.table("database.table_T1")
val cols = "amt_M,amt_N,amt_P"
val aggCols = cols.split(",").toSeq
val sums = aggCols.map(colName => when(round(sum(colName).cast(DoubleType), 1) =!= 0.0, sum(colName).cast(DoubleType).alias("sum_" + colName)))
val df2 = df.groupBy(col("col_A")).agg(sums.head, sums.tail:_*)
However, this is not giving the desired results. Please help me with this.
Input Data
+--------+--------------------+---------------------+----------------------+
|col_A |amt_M |amt_N |amt_P |
+--------+--------------------+---------------------+----------------------+
|5C-SVS-1|0.0 |0.04064912622009295 |1.6256888829356116E-4 |
|5C-SVS-1|0.0 |0.026542159153759487 |8.574900251977566E-4 |
|5C-SVS-1|0.0 |5.703894148377958E-5 |1.0745888408402782E-7 |
|5C-SVS-1|0.0 |0.0 |4.514561031069833E-4 |
|5C-SVS-1|0.0 |0.011794053124022862 |0.0020388259536434656 |
|5C-SVS-1|0.0 |7.55793849084569E-4 |0.0017105736019335327 |
|5C-SVS-1|0.0 |0.019303776946698548 |2.240625765755109E-5 |
|5C-SVS-1|0.0 |-8.028117213883126E-6|-2.1979360825171534E-6|
|5C-SVS-1|0.001940948839163001|0.029163686986129422 |0.09505621692309557 |
|5C-SVS-1|0.0 |2.515835289984397E-7 |1.1486227577926157E-8 |
|5C-SVS-1|0.0 |0.007844299114837874 |9.974187712854785E-4 |
|5C-SVS-1|0.0 |5.033123682586349E-4 |1.3644443189731007E-4 |
|5C-SVS-1|0.0 |0.026331681277001386 |6.022434166108063E-4 |
|5C-SVS-1|0.0 |8.098023638080503E-6 |1.0 |
|5C-SVS-1|0.0 |0.03655893437209876 |0.003113370686486882 |
|5C-SVS-1|0.0 |0.01409363925733864 |6.239415097038338E-4 |
|5C-SVS-1|0.0 |0.02171856350557304 |0.0 |
|5C-SVS-1|0.008435341548288601|0.03347191686227869 |0.35221710556006247 |
|5C-SVS-1|0.0 |-2.547132732700875E-6|-0.13073525789233997 |
|5C-SVS-1|0.006057441518729214|0.024036273783621134 |0.21447606070652467 |
+--------+--------------------+---------------------+----------------------+
Expected Output
+--------+---+
| col_A|pct|
+--------+---+
|5C-SVS-1|1.0|
+--------+---+
Thanks
CodePudding user response:
You could first groupBy your Dataframe on col_A, calculate the sums, and afterwards use a map operation to choose which sum you want to keep. Something like this:
import org.apache.spark.sql.types._
import org.apache.spark.sql.functions._  // col, sum, round used below
import org.apache.spark.sql.Row
import spark.implicits._  // encoders for .map/.toDF (already in scope in the spark-shell)
// Creating the necessary schema to control the types read in when reading in our CSV
val schema = new StructType()
.add("col_A", StringType)
.add("amt_M", DoubleType)
.add("amt_N", DoubleType)
.add("amt_P", DoubleType)
// Reading in the Dataframe using our premade schema. I put the data in a CSV
// file with ; as delimiters.
val df = spark.read
.option("header", "true")
.option("sep",";")
.schema(schema)
.csv("./someData.csv")
df.show
+--------+--------------------+--------------------+--------------------+
| col_A| amt_M| amt_N| amt_P|
+--------+--------------------+--------------------+--------------------+
|5C-SVS-1| 0.0| 0.04064912622009295|1.625688882935611...|
|5C-SVS-1| 0.0|0.026542159153759487|8.574900251977566E-4|
|5C-SVS-1| 0.0|5.703894148377958E-5|1.074588840840278...|
|5C-SVS-1| 0.0| 0.0|4.514561031069833E-4|
|5C-SVS-1| 0.0|0.011794053124022862|0.002038825953643...|
|5C-SVS-1| 0.0| 7.55793849084569E-4|0.001710573601933...|
|5C-SVS-1| 0.0|0.019303776946698548|2.240625765755109E-5|
|5C-SVS-1| 0.0|-8.02811721388312...|-2.19793608251715...|
|5C-SVS-1|0.001940948839163001|0.029163686986129422| 0.09505621692309557|
|5C-SVS-1| 0.0|2.515835289984397E-7|1.148622757792615...|
|5C-SVS-1| 0.0|0.007844299114837874|9.974187712854785E-4|
|5C-SVS-1| 0.0|5.033123682586349E-4|1.364444318973100...|
|5C-SVS-1| 0.0|0.026331681277001386|6.022434166108063E-4|
|5C-SVS-1| 0.0|8.098023638080503E-6| 1.0|
|5C-SVS-1| 0.0| 0.03655893437209876|0.003113370686486882|
|5C-SVS-1| 0.0| 0.01409363925733864|6.239415097038338E-4|
|5C-SVS-1| 0.0| 0.02171856350557304| 0.0|
|5C-SVS-1|0.008435341548288601| 0.03347191686227869| 0.35221710556006247|
|5C-SVS-1| 0.0|-2.54713273270087...|-0.13073525789233997|
|5C-SVS-1|0.006057441518729214|0.024036273783621134| 0.21447606070652467|
+--------+--------------------+--------------------+--------------------+
// Aggregating our data for each distinct value in col_A, summing all the amt columns
val aggregated_df = df.groupBy(col("col_A"))
.agg(
round(sum(col("amt_M")).as("amt_M_sum"), 1),
round(sum(col("amt_N")).as("amt_N_sum"), 1),
round(sum(col("amt_P")).as("amt_P_sum"), 1)
)
aggregated_df.show
+--------+---------------------------------+---------------------------------+---------------------------------+
| col_A|round(sum(amt_M) AS amt_M_sum, 1)|round(sum(amt_N) AS amt_N_sum, 1)|round(sum(amt_P) AS amt_P_sum, 1)|
+--------+---------------------------------+---------------------------------+---------------------------------+
|5C-SVS-1| 0.0| 0.3| 1.5|
+--------+---------------------------------+---------------------------------+---------------------------------+
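Side note: if you would rather the result columns keep the plain names like amt_M_sum, a small variant of the same aggregation (just a sketch) is to alias the rounded expression itself:
// Variant: alias the whole rounded expression so the result columns get clean names
val aggregated_df2 = df.groupBy(col("col_A"))
  .agg(
    round(sum(col("amt_M")), 1).as("amt_M_sum"),
    round(sum(col("amt_N")), 1).as("amt_N_sum"),
    round(sum(col("amt_P")), 1).as("amt_P_sum")
  )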
// Selecting our wanted values. We make use of Scala pattern matching here to
// easily deconstruct our data and make something readable
val output = aggregated_df.map(
row => row match {
case Row(col_A: String, sum_amt_M: Double, sum_amt_N: Double, sum_amt_P: Double) => {
if (sum_amt_M != 0.0)
(col_A, sum_amt_M)
else if (sum_amt_N != 0.0)
(col_A, sum_amt_N)
else
(col_A, sum_amt_P)
}
}
).toDF("col_A", "pct")
output.show
+--------+---+
| col_A|pct|
+--------+---+
|5C-SVS-1|0.3|
+--------+---+
Note: What do you do if all of the sums == 0? That's up to you to decide: I put the value of sum_amt_P as the else catch-all case. But from here on you can just adapt the logic inside of the map function to get whatever you want.
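For example, a minimal sketch (assuming you would rather have a null pct when every sum is 0.0 instead of falling back to sum_amt_P) could pick the first non-zero sum with find and let the Option become a nullable column:
// Sketch: take the first non-zero sum; None turns into a null pct when all sums are 0.0
val outputWithNull = aggregated_df.map {
  case Row(col_A: String, sum_amt_M: Double, sum_amt_N: Double, sum_amt_P: Double) =>
    (col_A, Seq(sum_amt_M, sum_amt_N, sum_amt_P).find(_ != 0.0))
}.toDF("col_A", "pct")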
Hope this helps!
CodePudding user response:
I solved the requirement by implementing the below method -
import org.apache.spark.sql.Column
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
def getSumCols(columnList: List[String]): Column = {
// Storing the value for the 1st index
var conditionColumn: Column = when(sum(col(columnList(0)).cast(DoubleType)) =!= 0.0, sum(col(columnList(0)).cast(DoubleType)))
// Iterating through the 2nd element till end and appending to existing variable created in the 1st step
for (c <- 1 until columnList.length) {
conditionColumn = conditionColumn.when( sum(col(columnList(c)).cast(DoubleType)) =!= 0.0, sum(col(columnList(c)).cast(DoubleType)) )
}
round(conditionColumn,1)
}
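For what it's worth, the same when-chain can also be built without a var by folding the remaining columns onto the first clause; a sketch with identical behavior (getSumColsFold is just an illustrative name):
// Sketch: fold the remaining columns onto the first when() clause instead of mutating a var
def getSumColsFold(columnList: List[String]): Column = {
  val first = when(sum(col(columnList.head).cast(DoubleType)) =!= 0.0, sum(col(columnList.head).cast(DoubleType)))
  val chained = columnList.tail.foldLeft(first) { (acc, c) =>
    acc.when(sum(col(c).cast(DoubleType)) =!= 0.0, sum(col(c).cast(DoubleType)))
  }
  round(chained, 1)
}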
The original getSumCols is then called during the aggregation as below -
val cols = "amt_M,amt_N,amt_P"
val colList = cols.split(",").toList
val conditionColumn: Column = getSumCols(colList)
val df1 = df.groupBy("col_A").agg(conditionColumn.alias("pct"))
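Since the column names come from the cols string, extending the logic later only means adding the new name there; for example, with a hypothetical extra column amt_Q:
// amt_Q is a hypothetical column, only to illustrate extending the list
val cols = "amt_M,amt_N,amt_P,amt_Q"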