I have the following dataset:
# Sample sales rows, one per (size, dt, store_id).
# Note that there is no row for size "L", store 200 on 2022-10-02.
test_table = spark.createDataFrame(
    [
        # dt = 2022-10-01, store 100
        ("US", "CA", "S", "2022-10-01", 100, 10, 1),
        ("US", "CA", "M", "2022-10-01", 100, 15, 5),
        ("US", "CA", "L", "2022-10-01", 100, 20, 10),
        # dt = 2022-10-01, store 200
        ("US", "CA", "S", "2022-10-01", 200, 10, 1),
        ("US", "CA", "M", "2022-10-01", 200, 15, 5),
        ("US", "CA", "L", "2022-10-01", 200, 20, 10),
        # dt = 2022-10-02, store 100
        ("US", "CA", "S", "2022-10-02", 100, 11, 1),
        ("US", "CA", "M", "2022-10-02", 100, 13, 3),
        ("US", "CA", "L", "2022-10-02", 100, 17, 7),
        # dt = 2022-10-02, store 200
        ("US", "CA", "S", "2022-10-02", 200, 11, 1),
        ("US", "CA", "M", "2022-10-02", 200, 13, 3),
    ],
    schema=[
        "country_code",
        "state_code",
        "size",
        "dt",
        "store_id",
        "ttl_sold",
        "ttl_returned",
    ],
)
I then do some aggregations and end up with 2 columns (latest_payload, prev_payload). Both columns are arrays of maps from store_id to a struct of (cumulative_ttl_sold, cumulative_ttl_returned), built as follows:
# Running-total window: within each (country, state, size, store), ordered by
# dt, from the start of the partition up to the current row.
w = (
    Window.partitionBy("country_code", "state_code", "size", "store_id")
    .orderBy("dt")
    .rangeBetween(Window.unboundedPreceding, Window.currentRow)
)

# Ordering window used to look one dt back within each (country, state, size).
w2 = Window.partitionBy("country_code", "state_code", "size").orderBy("dt")

df_w_cumulative_sum = (
    test_table
    # Per-store cumulative totals up to each dt.
    .withColumn("cumulative_ttl_sold", F.sum("ttl_sold").over(w))
    .withColumn("cumulative_ttl_returned", F.sum("ttl_returned").over(w))
    # Collapse the stores of each (dt, country, state, size) into a list of
    # single-entry maps: {store_id: (cumulative_ttl_sold, cumulative_ttl_returned)}.
    .groupBy("dt", "country_code", "state_code", "size")
    .agg(
        F.collect_list(
            F.create_map(
                "store_id",
                F.struct("cumulative_ttl_sold", "cumulative_ttl_returned"),
            )
        ).alias("latest_payload")
    )
    # Payload of the preceding dt for the same group (null when there is none).
    .withColumn("prev_payload", F.lag("latest_payload", 1).over(w2))
    # Keep only the most recent day.
    .where(F.col("dt") == "2022-10-02")
)
row | dt | country_code | state_code | size | latest_payload | prev_payload |
---|---|---|---|---|---|---|
1 | 2022-10-01 | US | CA | L | [{"100":{"cumulative_ttl_sold":20,"cumulative_ttl_returned":10}},{"200":{"cumulative_ttl_sold":20,"cumulative_ttl_returned":10}}] | null |
2 | 2022-10-01 | US | CA | M | [{"100":{"cumulative_ttl_sold":15,"cumulative_ttl_returned":5}},{"200":{"cumulative_ttl_sold":15,"cumulative_ttl_returned":5}}] | null |
3 | 2022-10-01 | US | CA | S | [{"100":{"cumulative_ttl_sold":10,"cumulative_ttl_returned":1}},{"200":{"cumulative_ttl_sold":10,"cumulative_ttl_returned":1}}] | null |
4 | 2022-10-02 | US | CA | L | [{"100":{"cumulative_ttl_sold":37,"cumulative_ttl_returned":17}}] | [{"100":{"cumulative_ttl_sold":20,"cumulative_ttl_returned":10}},{"200":{"cumulative_ttl_sold":20,"cumulative_ttl_returned":10}}] |
5 | 2022-10-02 | US | CA | M | [{"100":{"cumulative_ttl_sold":28,"cumulative_ttl_returned":8}},{"200":{"cumulative_ttl_sold":28,"cumulative_ttl_returned":8}}] | [{"100":{"cumulative_ttl_sold":15,"cumulative_ttl_returned":5}},{"200":{"cumulative_ttl_sold":15,"cumulative_ttl_returned":5}}] |
6 | 2022-10-02 | US | CA | S | [{"100":{"cumulative_ttl_sold":21,"cumulative_ttl_returned":2}},{"200":{"cumulative_ttl_sold":21,"cumulative_ttl_returned":2}}] | [{"100":{"cumulative_ttl_sold":10,"cumulative_ttl_returned":1}},{"200":{"cumulative_ttl_sold":10,"cumulative_ttl_returned":1}}] |
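Expected output for row 4 (store 200 has no row for size L on 2022-10-02, so its entry should be carried over from prev_payload):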
{'100': {'cumulative_ttl_sold': 37, 'cumulative_ttl_returned': 17}, '200': {'cumulative_ttl_sold': 20, 'cumulative_ttl_returned': 10}}
Attempted solution (it gives me the wrong values for each row):
@F.udf(
    MapType(
        # NOTE(review): store_id comes from Python ints, which Spark normally
        # infers as LongType -- confirm IntegerType is the intended key type.
        IntegerType(),
        StructType([
            StructField("cumulative_ttl_sold", LongType(), False),
            # Fixed: this field was previously declared a second time as
            # "cumulative_ttl_sold", so the returned struct had no
            # "cumulative_ttl_returned" field.
            StructField("cumulative_ttl_returned", LongType(), False),
        ]),
    )
)
def merge_payloads(lastest_payload, prev_payload):
    """Merge two payload lists into a single {store_id: totals} map.

    Each payload is a list of maps keyed by store_id (or None when the
    column is null). Entries from ``lastest_payload`` always win; entries
    from ``prev_payload`` are only used for store ids that are missing from
    the latest payload, which carries forward stores with no row for the
    current day.

    Args:
        lastest_payload: list of {store_id: totals} maps for the current
            day, or None.
        prev_payload: list of {store_id: totals} maps for the previous day,
            or None.

    Returns:
        A dict mapping every store id seen in either payload to its totals.
        Empty when both payloads are None or empty.
    """
    merged = {}

    # Latest values take precedence, so load them first.
    for entry in lastest_payload or ():
        merged.update(entry)

    # Fill in only the stores the latest payload does not cover. Every key of
    # every previous map is examined; a key that is already present no longer
    # aborts the scan of the remaining keys.
    for entry in prev_payload or ():
        for store_id, totals in entry.items():
            merged.setdefault(store_id, totals)

    return merged
CodePudding user response:
Wrap this in the appropriate UDF decorator (with the correct return type) and it will do what you're looking for:
def merge_payloads(latest_payload, prev_payload):
    """Merge two payload lists into one {store_id: totals} dict.

    Each payload is a list of mappings keyed by store id, or None (for
    example the lagged payload of the first day, which has no predecessor).
    Previous-day entries are applied first and latest entries second, so a
    store present in both ends up with its latest totals, while a store that
    only appears in the previous payload is carried forward unchanged.

    Args:
        latest_payload: list of {store_id: totals} mappings, or None.
        prev_payload: list of {store_id: totals} mappings, or None.

    Returns:
        A dict with one entry per store id found in either payload; empty
        when both payloads are None or empty.
    """
    merged = {}
    # A None payload is treated as empty instead of raising TypeError on
    # unpacking.
    for payload in (prev_payload, latest_payload):
        for entry in payload or ():
            merged.update(entry)
    return merged
# Example inputs: the latest payload only has store 100, while the previous
# payload has both store 100 and store 200.
latest = [
    {"100": {"cumulative_ttl_sold": 37, "cumulative_ttl_returned": 17}},
]
prev = [
    {"100": {"cumulative_ttl_sold": 20, "cumulative_ttl_returned": 10}},
    {"200": {"cumulative_ttl_sold": 20, "cumulative_ttl_returned": 10}},
]

print(merge_payloads(latest, prev))
# Output:
{'100': {'cumulative_ttl_sold': 37, 'cumulative_ttl_returned': 17},
'200': {'cumulative_ttl_sold': 20, 'cumulative_ttl_returned': 10}}