How to assign non unique incrementing index (index markup) in Spark SQL, set back to 0 on joining th

Time:12-18

I have a DataFrame of data like this:

|timestamp          |value|
|2021-01-01 12:00:00| 10.0|
|2021-01-01 12:00:01| 10.0|
|2021-01-01 12:00:02| 10.0|
|2021-01-01 12:00:03| 10.0|
|2021-01-01 12:00:04| 10.0|
|2021-01-01 12:00:05| 10.0|
|2021-01-01 12:00:06| 10.0|
|2021-01-01 12:00:07| 10.0|

and a DataFrame of events like this:

|timestamp          |event|
|2021-01-01 12:00:01| true|
|2021-01-01 12:00:05| true|

Based on that, I'd like to add one more column to the initial DataFrame: an index that counts the rows since the beginning of the most recent event:

|timestamp          |value|index|
|2021-01-01 12:00:00| 10.0|    1|
|2021-01-01 12:00:01| 10.0|    2|
|2021-01-01 12:00:02| 10.0|    3|
|2021-01-01 12:00:03| 10.0|    4|
|2021-01-01 12:00:04| 10.0|    5|
|2021-01-01 12:00:05| 10.0|    1|
|2021-01-01 12:00:06| 10.0|    2|
|2021-01-01 12:00:07| 10.0|    3|

I have tried with

.withColumn("index",monotonically_increasing_id())

but I see no way to reset it when an event from the other DataFrame occurs. Any ideas are welcome.

CodePudding user response:

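You can left-join the data DataFrame with the events DataFrame on timestamp, then use a conditional cumulative sum over the event column to define groups: the running total increases by one at every event row, so all rows up to the next event share the same group number. Finally, partition by the group column and assign a row number within each group.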

Something like this:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._  // col, when, sum, row_number

val result = data.join(
    events, 
    Seq("timestamp"), 
    "left"
).withColumn(
    "group",
    sum(when(col("event"), 1).otherwise(0)).over(Window.orderBy("timestamp"))
).withColumn(
    "index",
    row_number().over(Window.partitionBy("group").orderBy("timestamp"))
).drop("group", "event")

result.show
//+-------------------+-----+-----+
//|          timestamp|value|index|
//+-------------------+-----+-----+
//|2021-01-01 12:00:00| 10.0|    1|
//|2021-01-01 12:00:01| 10.0|    1|
//|2021-01-01 12:00:02| 10.0|    2|
//|2021-01-01 12:00:03| 10.0|    3|
//|2021-01-01 12:00:04| 10.0|    4|
//|2021-01-01 12:00:05| 10.0|    1|
//|2021-01-01 12:00:06| 10.0|    2|
//|2021-01-01 12:00:07| 10.0|    3|
//+-------------------+-----+-----+
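
To see why the cumulative sum produces these groups, the same logic can be traced without Spark. This is only a plain-Python sketch of the idea, not Spark code: the function name index_since_event and the in-memory lists are hypothetical, and the rows are assumed to be sorted by timestamp already.

```python
from itertools import accumulate


def index_since_event(timestamps, event_timestamps):
    """Return (timestamp, group, index) tuples for rows sorted by timestamp.

    Hypothetical helper that mirrors the window logic in plain Python.
    """
    events = set(event_timestamps)
    # 1 where an event fires, 0 elsewhere (the when/otherwise step).
    flags = [1 if ts in events else 0 for ts in timestamps]
    # Running total of the flags: every event row starts a new group.
    groups = list(accumulate(flags))
    result = []
    counters = {}
    for ts, group in zip(timestamps, groups):
        # Position of the row inside its group (the row_number step).
        counters[group] = counters.get(group, 0) + 1
        result.append((ts, group, counters[group]))
    return result


rows = index_since_event(
    ["12:00:00", "12:00:01", "12:00:02", "12:00:03",
     "12:00:04", "12:00:05", "12:00:06", "12:00:07"],
    ["12:00:01", "12:00:05"],
)
for row in rows:
    print(row)
```

Rows that come before the first event end up in group 0 and are counted on their own, which is why 12:00:00 gets index 1 and the count restarts at 12:00:01.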

CodePudding user response:

You could use a Window function to achieve it:

from pyspark.sql import SparkSession, Row, Window
from pyspark.sql import functions as F

spark = SparkSession.builder.getOrCreate()

Example data after joining the original DFs (I changed the timestamp column to integer type for simplicity):

df = spark.createDataFrame([
    Row(timestamp=0, value='foo', event=True),
    Row(timestamp=1, value='foo', event=None),
    Row(timestamp=2, value='foo', event=None),
    Row(timestamp=3, value='foo', event=None),
    Row(timestamp=4, value='foo', event=None),
    Row(timestamp=5, value='foo', event=True),
    Row(timestamp=6, value='foo', event=None),
    Row(timestamp=7, value='foo', event=None),
])

Then I create a group_id column by forward-filling the timestamp of the most recent event, so every row carries the start timestamp of its group. This group_id can then be used to create the index using F.row_number():

(
    df
    .withColumn('group_id', F.when(F.col('event'), F.col('timestamp')))
    .withColumn('group_id', F.last('group_id', ignorenulls=True).over(Window.orderBy('timestamp')))
    .withColumn('index', F.row_number().over(Window.partitionBy('group_id').orderBy('timestamp')))
    .show()
)

# Output:
+---------+-----+-----+--------+-----+
|timestamp|value|event|group_id|index|
+---------+-----+-----+--------+-----+
|        0|  foo| true|       0|    1|
|        1|  foo| null|       0|    2|
|        2|  foo| null|       0|    3|
|        3|  foo| null|       0|    4|
|        4|  foo| null|       0|    5|
|        5|  foo| true|       5|    1|
|        6|  foo| null|       5|    2|
|        7|  foo| null|       5|    3|
+---------+-----+-----+--------+-----+
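
The forward-fill step can also be traced outside Spark. The following is a minimal plain-Python sketch under the assumption that the rows are already sorted by timestamp; forward_fill_index is a hypothetical name, not part of PySpark.

```python
def forward_fill_index(rows):
    """rows: list of (timestamp, event) tuples sorted by timestamp.

    Hypothetical helper: returns (timestamp, group_id, index) tuples.
    """
    result = []
    group_id = None  # stays None until the first event is seen
    index = 0
    for timestamp, event in rows:
        if event:
            # An event starts a new group, identified by its own timestamp.
            group_id = timestamp
            index = 0
        index += 1
        result.append((timestamp, group_id, index))
    return result


joined = [(0, True), (1, None), (2, None), (3, None),
          (4, None), (5, True), (6, None), (7, None)]
filled = forward_fill_index(joined)
for row in filled:
    print(row)
```

In this sketch, rows that precede the first event keep a group_id of None and are counted together as one group.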