Assigning parent to spark row

Time:08-28

I have a spark dataframe like this:

val dataSeq = Seq(
    (20, "data", 0),
    (21, "data", 1),
    (99, "data", 2),
    (99, "data", 3),
    (99, "data", 4),
    (25, "data", 5),
    (99, "data", 6),
    (99, "data", 7))

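import spark.implicits._ // needed for toDF; assumes a SparkSession named spark (spark-shell imports this automatically)
val df = dataSeq.toDF("A", "B", "index")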
df.show
+---+----+-----+
|  A|   B|index|
+---+----+-----+
| 20|data|    0|
| 21|data|    1|
| 99|data|    2|
| 99|data|    3|
| 99|data|    4|
| 25|data|    5|
| 99|data|    6|
| 99|data|    7|
+---+----+-----+

In this case, each 99 row is a child of the nearest non-99 row before it: indices 2, 3 and 4 are children of index 1, and indices 6 and 7 are children of index 5. I want to create a DataFrame containing only the 99 rows, with an additional column holding the parent index so that I can group them later.

I've split the DataFrame by code using the following:

val code20Df = df.filter(df("A") === 20)
val code21Df = df.filter(df("A") === 21)
val code25Df = df.filter(df("A") === 25)
val code99Df = df.filter(df("A") === 99)

I want the code99Df to look like:

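+---+----+-----+-----------+
|  A|   B|index|parentIndex|
+---+----+-----+-----------+
| 99|data|    2|          1|
| 99|data|    3|          1|
| 99|data|    4|          1|
| 99|data|    6|          5|
| 99|data|    7|          5|
+---+----+-----+-----------+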

Is there an aggregate function that will allow me to keep track of the "parent", assign that value to children and then reset that value once the row value is not 99, or will I have to resort to using a for loop?

Answer:

import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions.Window

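// No partitionBy: Spark moves all rows into a single partition for this window, which is fine for small data
val w = Window.orderBy("index").rowsBetween(Window.unboundedPreceding, Window.currentRow)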
df.withColumn(
    "parentId", 
    when(
        col("A")===99, 
        last(
            when(col("A")=!=99, col("index")), 
            ignoreNulls=true
        ).over(w)
    )
).show
+---+----+-----+--------+
|  A|   B|index|parentId|
+---+----+-----+--------+
| 20|data|    0|    null|
| 21|data|    1|    null|
| 99|data|    2|       1|
| 99|data|    3|       1|
| 99|data|    4|       1|
| 25|data|    5|    null|
| 99|data|    6|       5|
| 99|data|    7|       5|
+---+----+-----+--------+

First, mark the rows that can act as parents. Suppose we had a column for this and called it candidateParentId:

when(col("A")=!=99, col("index")).as("candidateParentId")

That is, candidateParentId is null for rows where A == 99 and holds the row's own index for rows where A != 99.
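As a plain-Scala illustration (no Spark involved), the candidate step can be modelled with Option, where None plays the role of Spark's null. The object CandidateStep and its case class Rec are hypothetical names used only for this sketch.

```scala
// Plain-Scala sketch (no Spark) of the candidateParentId step.
object CandidateStep {
  // Hypothetical record type mirroring the A, B and index columns.
  final case class Rec(a: Int, b: String, index: Int)

  val rows: Seq[Rec] = Seq(
    Rec(20, "data", 0), Rec(21, "data", 1), Rec(99, "data", 2),
    Rec(99, "data", 3), Rec(99, "data", 4), Rec(25, "data", 5),
    Rec(99, "data", 6), Rec(99, "data", 7))

  // Same rule as when(col("A") =!= 99, col("index")):
  // a non-99 row is a candidate parent (Some(index)); a 99 row gets None (null in Spark).
  def candidateParentId(r: Rec): Option[Int] =
    if (r.a != 99) Some(r.index) else None

  def main(args: Array[String]): Unit =
    rows.foreach(r => println(s"index ${r.index}: ${candidateParentId(r)}"))
}
```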

Then, for each row, we look backwards (including the row itself) and take the nearest non-null candidateParentId:

val w = Window.orderBy("index").rowsBetween(Window.unboundedPreceding, Window.currentRow)
last(when(col("A")=!=99, col("index")), ignoreNulls=true).over(w)
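The window expression behaves like a running "last non-empty value seen so far", i.e. a forward fill. A minimal plain-Scala sketch of that idea uses scanLeft over the rows in index order; ForwardFill, Rec, lastNonNull and candidates are hypothetical names, and this is only an analogy for what Spark computes, not how Spark implements it.

```scala
// Plain-Scala sketch (no Spark) of
// last(..., ignoreNulls = true).over(rowsBetween(unboundedPreceding, currentRow)).
object ForwardFill {
  // Hypothetical record type mirroring the A, B and index columns.
  final case class Rec(a: Int, b: String, index: Int)

  val rows: Seq[Rec] = Seq(
    Rec(20, "data", 0), Rec(21, "data", 1), Rec(99, "data", 2),
    Rec(99, "data", 3), Rec(99, "data", 4), Rec(25, "data", 5),
    Rec(99, "data", 6), Rec(99, "data", 7))

  // For each position, the most recent defined value up to and including it.
  // scanLeft emits the seed (None) first, so .tail drops it.
  def lastNonNull(values: Seq[Option[Int]]): Seq[Option[Int]] =
    values.scanLeft(Option.empty[Int])((prev, cur) => cur.orElse(prev)).tail

  // Candidate parents in index order (Some(index) for non-99 rows, None otherwise).
  def candidates: Seq[Option[Int]] =
    rows.sortBy(_.index).map(r => if (r.a != 99) Some(r.index) else None)

  def main(args: Array[String]): Unit =
    println(lastNonNull(candidates))
}
```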

This fills in a value for every row, but we only want it on rows where A == 99, so one final when gives us the parentId column:

when(
    col("A")===99,
    last(
        when(col("A")=!=99, col("index")),
        ignoreNulls=true
    ).over(w)
)
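The question asks for a DataFrame containing only the 99 rows. In Spark that would mean appending .filter(col("A") === 99) to the result and, to match the column name from the question, .withColumnRenamed("parentId", "parentIndex"). The sketch below reproduces the whole pipeline in plain Scala so the logic can be checked without a cluster. ParentIndex, Rec, Child and children are hypothetical names, and parentIndex is an Option[Int] so that a 99 row with no earlier non-99 row ends up as None, mirroring Spark's null.

```scala
// Plain-Scala sketch (no Spark) of the full pipeline, keeping only the 99 rows.
object ParentIndex {
  // Hypothetical record types: input rows and the desired output rows.
  final case class Rec(a: Int, b: String, index: Int)
  final case class Child(a: Int, b: String, index: Int, parentIndex: Option[Int])

  val rows: Seq[Rec] = Seq(
    Rec(20, "data", 0), Rec(21, "data", 1), Rec(99, "data", 2),
    Rec(99, "data", 3), Rec(99, "data", 4), Rec(25, "data", 5),
    Rec(99, "data", 6), Rec(99, "data", 7))

  def children(input: Seq[Rec]): Seq[Child] = {
    val sorted = input.sortBy(_.index)
    // Running index of the most recent non-99 row (None until one is seen).
    val lastParent = sorted
      .scanLeft(Option.empty[Int])((prev, r) => if (r.a != 99) Some(r.index) else prev)
      .tail
    // Keep only the 99 rows, like filter(col("A") === 99) in Spark.
    sorted.zip(lastParent).collect {
      case (r, parent) if r.a == 99 => Child(r.a, r.b, r.index, parent)
    }
  }

  def main(args: Array[String]): Unit =
    children(rows).foreach(println)
}
```

Running main prints the five 99 rows, with parentIndex Some(1) for indices 2 to 4 and Some(5) for indices 6 and 7, which matches the code99Df the question asks for.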