Spark SQL: find the number of extensions for a record


I have a dataset like the one below:

col1 extension_col1
2345 2246
2246 2134
2134 2091
2091 Null
1234 1111
1111 Null

I need to find the number of extensions available for each record in col1. The records are already sorted and stored contiguously as sets, and each set is terminated by a row whose extension_col1 is null.

The expected result is:

col1 extension_col1 No_Of_Extensions
2345 2246 3
2246 2134 2
2134 2091 1
2091 Null 0
1234 1111 1
1111 Null 0

Value 2345 extends as 2345 > 2246 > 2134 > 2091 > null, and hence it has 3 extension relations, excluding the null.
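To pin down the definition, here is a minimal plain-Scala sketch (no Spark) that counts the extensions of a value by following its links until it reaches a null. It assumes col1 values are unique and that the chains contain no cycles; the object and method names are illustrative only. It checks the definition rather than solving the problem in Spark, where following links like this would typically need repeated self-joins.

```scala
// Plain Scala (no Spark): count how many links a value has
// before its chain reaches a null extension.
object ExtensionChains {
  // Each row is (col1, extension_col1); None stands for Null.
  val rows: Seq[(Int, Option[Int])] = Seq(
    2345 -> Some(2246), 2246 -> Some(2134), 2134 -> Some(2091),
    2091 -> None, 1234 -> Some(1111), 1111 -> None
  )

  // Lookup from a value to its extension (assumes unique col1 values).
  private val next: Map[Int, Option[Int]] = rows.toMap

  // Follow the links until a null is reached; each non-null link counts once.
  // Assumes there are no cycles, otherwise this would never terminate.
  @annotation.tailrec
  def countExtensions(value: Int, acc: Int = 0): Int =
    next.getOrElse(value, None) match {
      case Some(ext) => countExtensions(ext, acc + 1)
      case None      => acc
    }

  def main(args: Array[String]): Unit =
    rows.foreach { case (c, e) =>
      val ext = e.fold("Null")(_.toString)
      println(s"$c $ext ${countExtensions(c)}")
    }
}
```

For 2345 the walk is 2345 > 2246 > 2134 > 2091 > null, giving 3.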

How can I get the third column (No_Of_Extensions) using Spark SQL/Scala?

You can achieve that using window functions. First, create a group column grp with a cumulative conditional sum on extension_col1 (counting the nulls in descending col1 order). Then apply row_number on a window partitioned by grp and ordered by col1, this time ascending, to get the desired result:
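To see what the two windows compute, here is a plain-Scala simulation (no Spark) of both steps on the sample data. Like the answer, it assumes col1 values are unique and descend through the dataset, so ordering by col1 descending reproduces the original row order; the object and method names are hypothetical. The intermediate grp values come out as 0, 0, 0, 1, 1, 2: each null row falls into the following group, which is harmless because null rows are forced to 0 anyway.

```scala
// Plain Scala (no Spark) simulation of the two window steps.
object WindowSimulation {
  val rows: Seq[(Int, Option[Int])] = Seq(
    2345 -> Some(2246), 2246 -> Some(2134), 2134 -> Some(2091),
    2091 -> None, 1234 -> Some(1111), 1111 -> None
  )

  // Step 1 (w1): inclusive running sum of "is null" flags over col1 descending.
  def withGroups: Seq[(Int, Option[Int], Int)] = {
    val sorted = rows.sortBy { case (c, _) => -c }
    val flags  = sorted.map { case (_, e) => if (e.isEmpty) 1 else 0 }
    val grps   = flags.scanLeft(0)(_ + _).tail
    sorted.zip(grps).map { case ((c, e), g) => (c, e, g) }
  }

  // Step 2 (w2): row_number within each grp, ordered by col1 ascending,
  // replaced by 0 on the null-terminated rows. Output in col1 descending order.
  def withExtensions: Seq[(Int, Option[Int], Int)] =
    withGroups.groupBy(_._3).values.flatMap { part =>
      part.sortBy(_._1).zipWithIndex.map { case ((c, e, _), i) =>
        (c, e, if (e.isEmpty) 0 else i + 1)
      }
    }.toSeq.sortBy { case (c, _, _) => -c }

  def main(args: Array[String]): Unit =
    withExtensions.foreach(println)
}
```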

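import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

// w1: one global window, rows in descending col1 order (the original row order here)
val w1 = Window.orderBy(desc("col1"))
// w2: one window per group, ascending col1 so the row just before the null gets 1
val w2 = Window.partitionBy("grp").orderBy("col1")

val result = df
  .withColumn("grp", sum(when(col("extension_col1").isNull, 1).otherwise(0)).over(w1))
  .withColumn("No_Of_Extensions",
    when(col("extension_col1").isNull, 0).otherwise(row_number().over(w2)))
  .drop("grp")

result.show()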

//+----+--------------+----------------+
//|col1|extension_col1|No_Of_Extensions|
//+----+--------------+----------------+
//|2345|          2246|               3|
//|2246|          2134|               2|
//|2134|          2091|               1|
//|2091|          null|               0|
//|1234|          1111|               1|
//|1111|          null|               0|
//+----+--------------+----------------+

Note that the first sum uses a non-partitioned window, so Spark moves all the data into a single partition, which can hurt performance on large datasets.

Spark-SQL equivalent query:

SELECT col1,
       extension_col1,
       case when extension_col1 is null then 0 else row_number() over(partition by grp order by col1) end as No_Of_Extensions
FROM (
      SELECT *,
             sum(case when extension_col1 is null then 1 else 0 end) over(order by col1 desc) as grp
      FROM df
     ) t
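To run the SQL version from Scala, the DataFrame has to be registered under the name the query uses. The sketch below does that with createOrReplaceTempView and spark.sql on a local SparkSession; the object name, the local master setting and the outer ORDER BY (added only for a stable display order) are assumptions for illustration.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}

object SqlVersion {
  // The answer's query, plus an outer ORDER BY so the display order is stable.
  val query: String =
    """SELECT col1,
      |       extension_col1,
      |       case when extension_col1 is null then 0
      |            else row_number() over(partition by grp order by col1) end as No_Of_Extensions
      |FROM (
      |      SELECT *,
      |             sum(case when extension_col1 is null then 1 else 0 end) over(order by col1 desc) as grp
      |      FROM df
      |     ) t
      |ORDER BY col1 DESC""".stripMargin

  def run(spark: SparkSession): DataFrame = {
    import spark.implicits._
    // Sample data from the question; None plays the role of Null.
    val df = Seq(
      (2345, Option(2246)), (2246, Option(2134)), (2134, Option(2091)),
      (2091, Option.empty[Int]), (1234, Option(1111)), (1111, Option.empty[Int])
    ).toDF("col1", "extension_col1")
    // Register the DataFrame under the name the SQL text refers to.
    df.createOrReplaceTempView("df")
    spark.sql(query)
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("extensions").getOrCreate()
    try run(spark).show()
    finally spark.stop()
  }
}
```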

An alternative to blackbishop's answer, in that I assume the data may not always be ordered by col1, and therefore number the rows by their position instead. I like the conditional summing, but it is not applicable here.

In all honesty, this is a bad use case for Spark at scale: like the other answer states, I cannot get around the single-partition aspect either. That said, partition sizes can be larger on newer Spark versions, and the 'lists' in this example may be long in practice.

Part 1 - Generate data

// 1. Generate data (spark.implicits._ is already in scope in spark-shell).
import spark.implicits._
val df = Seq(( Some(2345), Some(22246) ), ( Some(22246), Some(2134) ), ( Some(2134), Some(2091) ), (Some(2091), None) ,
              ( Some(1234), Some(1111) ), ( Some(1111), None )
             ).toDF("col1" ,"extCol1")

Part 2 - Actual processing

//2. Narrow transform: add each row's position in the dataset, as the values may not always be in descending or ascending order.
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{StructField,StructType,IntegerType, ArrayType, LongType}
val newSchema = StructType(df.schema.fields ++ Array(StructField("rowid", LongType, false)))
val rdd = df.rdd.zipWithIndex
val df2 = spark.createDataFrame(rdd.map{ case (row, index) => Row.fromSeq(row.toSeq ++ Array(index))}, newSchema)  // Some cost: zipWithIndex runs an extra job when there is more than one partition

//3. Group the rows into ranges, each ending at a null row. The window below has no partitioning, so it cannot avoid the single-partition aspect;
//   this only works for data that fits into a single partition. At scale, one would need some grouping characteristic to partition by.
val dfg = df2.filter(df2("extCol1").isNull)

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
val winSpec1 = Window.orderBy(asc("rowid"))

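val dfg2 = dfg.withColumn("prev_rowid_tmp", lag("rowid", 1, -1).over(winSpec1))
              .withColumn("rowidFrom", $"prev_rowid_tmp" + 1)  // a range starts right after the previous null row
              .withColumnRenamed("rowid", "rowidTo")            // ...and ends at this null row
              .select("rowidFrom", "rowidTo")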

//4. Apply grouping of ranges of rows to data.
val df3 = df2.as("df2").join(dfg2.as("dfg2"), 
          $"df2.rowid" >= $"dfg2.rowidFrom" && $"df2.rowid" <= $"dfg2.rowidTo", "inner")             

//5. Do the calcs: the distance from each row to the null row that ends its range.
val res = df3.withColumn("numExtensions", $"rowidTo" - $"rowid")
res.select("df2.col1", "df2.extCol1", "numExtensions").show(false)


+-----+-------+-------------+
|col1 |extCol1|numExtensions|
+-----+-------+-------------+
|2345 |22246  |3            |
|22246|2134   |2            |
|2134 |2091   |1            |
|2091 |null   |0            |
|1234 |1111   |1            |
|1111 |null   |0            |
+-----+-------+-------------+
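The position arithmetic behind steps 2 to 5 can be checked without Spark. This plain-Scala sketch (the object name is hypothetical) numbers the rows the way zipWithIndex does, turns the positions of the null rows into (rowidFrom, rowidTo) ranges as the lag step does, and takes rowidTo - rowid for every row inside a range. As with the inner join above, rows after the last null row would simply be dropped.

```scala
// Plain Scala (no Spark) sketch of the position-based approach.
object PositionBased {
  val rows: Seq[(Int, Option[Int])] = Seq(
    2345 -> Some(22246), 22246 -> Some(2134), 2134 -> Some(2091),
    2091 -> None, 1234 -> Some(1111), 1111 -> None
  )

  def numExtensions: Seq[(Int, Option[Int], Long)] = {
    // Step 2 equivalent: attach a 0-based position to every row.
    val indexed = rows.zipWithIndex.map { case ((c, e), i) => (c, e, i.toLong) }
    // Step 3 equivalent: each null row closes a range (rowidTo); the range
    // starts right after the previous null row (rowidFrom), or at 0.
    val ends   = indexed.collect { case (_, None, i) => i }
    val ranges = ends.zip(-1L +: ends).map { case (to, prev) => (prev + 1, to) }
    // Steps 4-5 equivalent: range "join", then rowidTo - rowid.
    for {
      (c, e, i)  <- indexed
      (from, to) <- ranges if i >= from && i <= to
    } yield (c, e, to - i)
  }

  def main(args: Array[String]): Unit =
    numExtensions.foreach(println)
}
```

Here the null rows sit at positions 3 and 5, giving the ranges (0, 3) and (4, 5).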
