I have a Spark DataFrame as below:
from pyspark.sql import SparkSession
from pyspark.sql import Window
import pyspark.sql.functions as F
data = [{"Category": 'A', "ID": 1, "Value": 121.44, "Truth": True, "time": 1},
{"Category": 'B', "ID": 2, "Value": 300.01, "Truth": False, "time": 2},
{"Category": 'C', "ID": 3, "Value": 10.99, "Truth": None, "time": 3},
{"Category": 'C', "ID": 4, "Value": 33.87, "Truth": True, "time": 4},
{"Category": 'D', "ID": 4, "Value": 33.87, "Truth": True, "time": 5},
{"Category": 'E', "ID": 4, "Value": 33.87, "Truth": True, "time": 6},
{"Category": 'E', "ID": 4, "Value": 33.87, "Truth": True, "time": 7},
{"Category": 'E', "ID": 4, "Value": 33.87, "Truth": True, "time": 8}
]
spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(data)
w = Window.partitionBy(F.col("Category")).rowsBetween(Window.unboundedPreceding, Window.currentRow)
df.filter(df["ID"] == 4).withColumn("val_sum", F.sum(F.col("Value")).over(w)).withColumn("max_time", F.max(F.col("time")).over(w)).show()
and I received the following output:
+--------+---+-----+-----+----+------------------+--------+
|Category| ID|Truth|Value|time|           val_sum|max_time|
+--------+---+-----+-----+----+------------------+--------+
|       C|  4| true|33.87|   4|             33.87|       4|
|       D|  4| true|33.87|   5|             33.87|       5|
|       E|  4| true|33.87|   6|             33.87|       6|
|       E|  4| true|33.87|   7|             67.74|       7|
|       E|  4| true|33.87|   8|101.60999999999999|       8|
+--------+---+-----+-----+----+------------------+--------+
However, my expected output is like below:
+--------+---+-----+-----+----+------------------+--------+
|Category| ID|Truth|Value|time|           val_sum|max_time|
+--------+---+-----+-----+----+------------------+--------+
|       C|  4| true|33.87|   4|             33.87|       4|
|       D|  4| true|33.87|   5|             33.87|       5|
|       E|  4| true|33.87|   6|             33.87|       8|
|       E|  4| true|33.87|   7|             67.74|       8|
|       E|  4| true|33.87|   8|101.60999999999999|       8|
+--------+---+-----+-----+----+------------------+--------+
Can anyone please assist me with this? Please do let me know if it is not clear so that I can provide more info.
CodePudding user response:
Define another window spec for the max function with unboundedPreceding and unboundedFollowing instead of currentRow as the frame end.
Example:
from pyspark.sql import SparkSession
from pyspark.sql import Window
import pyspark.sql.functions as F
data = [{"Category": 'A', "ID": 1, "Value": 121.44, "Truth": True, "time": 1},
{"Category": 'B', "ID": 2, "Value": 300.01, "Truth": False, "time": 2},
{"Category": 'C', "ID": 3, "Value": 10.99, "Truth": None, "time": 3},
{"Category": 'C', "ID": 4, "Value": 33.87, "Truth": True, "time": 4},
{"Category": 'D', "ID": 4, "Value": 33.87, "Truth": True, "time": 5},
{"Category": 'E', "ID": 4, "Value": 33.87, "Truth": True, "time": 6},
{"Category": 'E', "ID": 4, "Value": 33.87, "Truth": True, "time": 7},
{"Category": 'E', "ID": 4, "Value": 33.87, "Truth": True, "time": 8}
]
spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(data)
# running frame: from the first row of the partition up to the current row
w = Window.partitionBy(F.col("Category")).rowsBetween(Window.unboundedPreceding, Window.currentRow)
# full-partition frame: every row sees all rows of its Category
w1 = Window.partitionBy(F.col("Category")).rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
df.filter(df["ID"] == 4) \
    .withColumn("val_sum", F.sum(F.col("Value")).over(w)) \
    .withColumn("max_time", F.max(F.col("time")).over(w1)) \
    .show()
+--------+---+-----+-----+----+------------------+--------+
|Category| ID|Truth|Value|time|           val_sum|max_time|
+--------+---+-----+-----+----+------------------+--------+
|       E|  4| true|33.87|   6|             33.87|       8|
|       E|  4| true|33.87|   7|             67.74|       8|
|       E|  4| true|33.87|   8|101.60999999999999|       8|
|       D|  4| true|33.87|   5|             33.87|       5|
|       C|  4| true|33.87|   4|             33.87|       4|
+--------+---+-----+-----+----+------------------+--------+
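Why this works: with unboundedFollowing the frame covers the entire Category partition, so every row gets the partition-wide maximum of time. As a minimal alternative sketch (assuming the same df and w defined above; w2 is just a name chosen here): a window spec with no ordering and no explicit frame defaults to the whole partition in Spark, so it should produce the same max_time column.
# no orderBy and no frame: the default frame is the entire partition
w2 = Window.partitionBy(F.col("Category"))
df.filter(df["ID"] == 4) \
    .withColumn("val_sum", F.sum(F.col("Value")).over(w)) \
    .withColumn("max_time", F.max(F.col("time")).over(w2)) \
    .show()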