I have a dataset as below:
col1 | extension_col1 |
---|---|
2345 | 2246 |
2246 | 2134 |
2134 | 2091 |
2091 | Null |
1234 | 1111 |
1111 | Null |
I need to find the number of extensions available for each record in col1, where the records are already sorted and grouped contiguously into sets, each set terminated by a Null.
The final result should be as below:
col1 | extension_col1 | No_Of_Extensions |
---|---|---|
2345 | 2246 | 3 |
2246 | 2134 | 2 |
2134 | 2091 | 1 |
2091 | Null | 0 |
1234 | 1111 | 1 |
1111 | Null | 0 |
Value 2345 extends as 2345 > 2246 > 2134 > 2091 > Null, and hence it has 3 extension relations, excluding the Null.
How do I get the 3rd column (No_Of_Extensions) using Spark SQL/Scala?
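For reproducibility, here is a minimal sketch of how the input could be built as a DataFrame (my assumption: both columns are nullable integers, and the name df matches the answers below):

import spark.implicits._   // not needed in spark-shell

val df = Seq(
  (Some(2345), Some(2246)),
  (Some(2246), Some(2134)),
  (Some(2134), Some(2091)),
  (Some(2091), None),
  (Some(1234), Some(1111)),
  (Some(1111), None)
).toDF("col1", "extension_col1")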
CodePudding user response:
You can achieve that using some Window functions. First, create a group column grp using a cumulative conditional sum on extension_col1 over a window ordered by col1 descending. Then, using the row_number function on a window partitioned by grp and ordered by col1, this time ascending, you get the desired result:
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

// w1: non-partitioned window, descending by col1, for the cumulative group sum
val w1 = Window.orderBy(desc("col1"))
// w2: per-group window, ascending by col1, for the extension count
val w2 = Window.partitionBy("grp").orderBy("col1")

val result = df.withColumn(
  "grp",
  sum(when(col("extension_col1").isNull, 1).otherwise(0)).over(w1)
).withColumn(
  "No_Of_Extensions",
  when(col("extension_col1").isNull, 0).otherwise(row_number().over(w2))
).drop("grp")
result.show
//+----+--------------+----------------+
//|col1|extension_col1|No_Of_Extensions|
//+----+--------------+----------------+
//|2345|          2246|               3|
//|2246|          2134|               2|
//|2134|          2091|               1|
//|2091|          null|               0|
//|1234|          1111|               1|
//|1111|          null|               0|
//+----+--------------+----------------+
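For intuition, this is what the intermediate grp column looks like before it is dropped (values derived by hand from the cumulative-sum logic above, shown for illustration only):

//+----+--------------+---+
//|col1|extension_col1|grp|
//+----+--------------+---+
//|2345|          2246|  0|
//|2246|          2134|  0|
//|2134|          2091|  0|
//|2091|          null|  1|
//|1234|          1111|  1|
//|1111|          null|  2|
//+----+--------------+---+

Because the default window frame includes the current row, each null row is counted into its own sum and so falls into the next grp; that is harmless here, since null rows are assigned 0 directly by the when expression.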
Note that the first sum uses a non-partitioned window, so all the data will be moved into a single partition, which could hurt performance.
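When a window has no partitioning, Spark typically logs a warning along these lines (exact wording may vary by version):

WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.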
Spark-SQL equivalent query:
SELECT col1,
       extension_col1,
       CASE WHEN extension_col1 IS NULL THEN 0
            ELSE ROW_NUMBER() OVER (PARTITION BY grp ORDER BY col1)
       END AS No_Of_Extensions
FROM (
    SELECT *,
           SUM(CASE WHEN extension_col1 IS NULL THEN 1 ELSE 0 END) OVER (ORDER BY col1 DESC) AS grp
    FROM df
)
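To run the SQL version, the DataFrame first needs to be exposed as a temporary view whose name matches the FROM clause (a minimal sketch):

df.createOrReplaceTempView("df")
spark.sql("""
  SELECT col1,
         extension_col1,
         CASE WHEN extension_col1 IS NULL THEN 0
              ELSE ROW_NUMBER() OVER (PARTITION BY grp ORDER BY col1)
         END AS No_Of_Extensions
  FROM (
    SELECT *,
           SUM(CASE WHEN extension_col1 IS NULL THEN 1 ELSE 0 END) OVER (ORDER BY col1 DESC) AS grp
    FROM df
  )
""").show()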
CodePudding user response:
An alternative to blackbishop's answer, in that I assume the data may not always be ordered, and hence do some alternative processing. I like the conditional summing, but it is not applicable here.
In all honesty, this is a bad use case for Spark at scale, as I cannot get around the single-partition aspect either, as the other answer states. But the maximum partition size has increased in newer Spark versions, and maybe the 'lists' are long in this example.
Part 1 - Generate data
// 1. Generate data.
val df = Seq(
  (Some(2345), Some(2246)), (Some(2246), Some(2134)), (Some(2134), Some(2091)), (Some(2091), None),
  (Some(1234), Some(1111)), (Some(1111), None)
).toDF("col1", "extCol1")
Part 2 - Actual processing
//2. Narrow transform: add each row's position in the dataset, as values may not always be desc or asc.
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{StructField, StructType, LongType}
val newSchema = StructType(df.schema.fields ++ Array(StructField("rowid", LongType, false)))
val rdd = df.rdd.zipWithIndex
val df2 = spark.createDataFrame(rdd.map{ case (row, index) => Row.fromSeq(row.toSeq ++ Array(index)) }, newSchema) // Some cost
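For reference, df2 at this point carries the original columns plus a contiguous rowid (illustrative output, derived from the sample data):

//+----+-------+-----+
//|col1|extCol1|rowid|
//+----+-------+-----+
//|2345|   2246|    0|
//|2246|   2134|    1|
//|2134|   2091|    2|
//|2091|   null|    3|
//|1234|   1111|    4|
//|1111|   null|    5|
//+----+-------+-----+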
//3. Make groupings from record ranges. Cannot avoid the single-partition aspect, so this only works if the data fits into a single partition.
// At scale one would not really be able to do this unless there is some grouping characteristic to partition by.
val dfg = df2.filter(df2("extCol1").isNull)
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
val winSpec1 = Window.orderBy(asc("rowid"))
val dfg2 = dfg.withColumn("prev_rowid_tmp", lag("rowid", 1, -1).over(winSpec1))
              .withColumn("rowidFrom", $"prev_rowid_tmp" + 1)
              .drop("prev_rowid_tmp")
              .drop("extCol1")
              .withColumnRenamed("rowid", "rowidTo")
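Each null row thus yields one range row in dfg2, marking where its set starts and ends (illustrative output for the sample data):

//+----+-------+---------+
//|col1|rowidTo|rowidFrom|
//+----+-------+---------+
//|2091|      3|        0|
//|1111|      5|        4|
//+----+-------+---------+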
//4. Join the row-range groupings back onto the data.
val df3 = df2.as("df2").join(dfg2.as("dfg2"),
$"df2.rowid" >= $"dfg2.rowidFrom" && $"df2.rowid" <= $"dfg2.rowidTo", "inner")
//5. Do the calcs.
val res = df3.withColumn("numExtensions", $"rowidTo" - $"rowid")
res.select("df2.col1", "extCol1", "numExtensions").show(false)
returns:
+----+-------+-------------+
|col1|extCol1|numExtensions|
+----+-------+-------------+
|2345|2246   |3            |
|2246|2134   |2            |
|2134|2091   |1            |
|2091|null   |0            |
|1234|1111   |1            |
|1111|null   |0            |
+----+-------+-------------+