Home > OS >  Pyspark merge 2 Array of Maps into 1 column with missing keys
PySpark: merge two arrays of maps into one column, filling in missing keys

Time:11-01

I have the following dataset:

# Toy sales data: one row per (country, state, size, day, store) holding that
# day's sold/returned counts. Store 200 has no "L" row on 2022-10-02, which is
# the gap the payload merge has to fill from the previous day.
test_table = spark.createDataFrame(
    data=[
        # 2022-10-01: every size for both stores
        ("US", "CA", "S", "2022-10-01", 100, 10, 1),
        ("US", "CA", "M", "2022-10-01", 100, 15, 5),
        ("US", "CA", "L", "2022-10-01", 100, 20, 10),
        ("US", "CA", "S", "2022-10-01", 200, 10, 1),
        ("US", "CA", "M", "2022-10-01", 200, 15, 5),
        ("US", "CA", "L", "2022-10-01", 200, 20, 10),
        # 2022-10-02: store 100 reports S/M/L, store 200 only S and M
        ("US", "CA", "S", "2022-10-02", 100, 11, 1),
        ("US", "CA", "M", "2022-10-02", 100, 13, 3),
        ("US", "CA", "L", "2022-10-02", 100, 17, 7),
        ("US", "CA", "S", "2022-10-02", 200, 11, 1),
        ("US", "CA", "M", "2022-10-02", 200, 13, 3),
    ],
    schema=[
        "country_code",
        "state_code",
        "size",
        "dt",
        "store_id",
        "ttl_sold",
        "ttl_returned",
    ],
)

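I then do some aggregations (code below) and end up with two columns, latest_payload and prev_payload. Both have the type array<map<bigint, struct<cumulative_ttl_sold: bigint, cumulative_ttl_returned: bigint>>>, i.e. a list of single-entry maps from store_id to that store's cumulative totals.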

# w: running window per store, ordered by dt, from the first dt up to and
# including the current row's dt (a RANGE frame, so rows sharing the same dt
# value are included together).
w = Window.partitionBy("country_code", "state_code", "size", "store_id").orderBy("dt").rangeBetween(Window.unboundedPreceding,0)
# w2: orders the per-date aggregates of each (country, state, size) so that
# lag() below can fetch the previous date's payload.
w2 = Window.partitionBy("country_code", "state_code", "size").orderBy("dt")
df_w_cumulative_sum = (
    test_table
    # Running totals per (country, state, size, store) up to each dt.
    .withColumn("cumulative_ttl_sold", F.sum("ttl_sold").over(w))
    .withColumn("cumulative_ttl_returned", F.sum("ttl_returned").over(w))
    # One row per (dt, country, state, size) holding a list of single-entry
    # maps {store_id: struct(cumulative_ttl_sold, cumulative_ttl_returned)}.
    # NOTE(review): collect_list element order is not guaranteed by Spark.
    .groupBy("dt","country_code", "state_code", "size")
    .agg(F.collect_list(F.create_map(F.col("store_id"), F.struct(F.col("cumulative_ttl_sold"), F.col("cumulative_ttl_returned")))).alias("latest_payload"))
    # Payload of the preceding dt row for the same (country, state, size);
    # null for the first dt in each partition. This is the previous *present*
    # date, not necessarily the previous calendar day.
    .withColumn("prev_payload", F.lag(F.col("latest_payload"), 1).over(w2))
    # Filter only after lag() so prev_payload is still populated for this dt.
    .where(F.col("dt") == "2022-10-02")
)
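With the final `.where` filter removed, the intermediate result looks like this (the filter keeps only rows 4–6):

row dt country_code state_code size latest_payload prev_payload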
1 2022-10-01 US CA L [{"100":{"cumulative_ttl_sold":20,"cumulative_ttl_returned":10}},{"200":{"cumulative_ttl_sold":20,"cumulative_ttl_returned":10}}] null
2 2022-10-01 US CA M [{"100":{"cumulative_ttl_sold":15,"cumulative_ttl_returned":5}},{"200":{"cumulative_ttl_sold":15,"cumulative_ttl_returned":5}}] null
3 2022-10-01 US CA S [{"100":{"cumulative_ttl_sold":10,"cumulative_ttl_returned":1}},{"200":{"cumulative_ttl_sold":10,"cumulative_ttl_returned":1}}] null
4 2022-10-02 US CA L [{"100":{"cumulative_ttl_sold":37,"cumulative_ttl_returned":17}}] [{"100":{"cumulative_ttl_sold":20,"cumulative_ttl_returned":10}},{"200":{"cumulative_ttl_sold":20,"cumulative_ttl_returned":10}}]
5 2022-10-02 US CA M [{"100":{"cumulative_ttl_sold":28,"cumulative_ttl_returned":8}},{"200":{"cumulative_ttl_sold":28,"cumulative_ttl_returned":8}}] [{"100":{"cumulative_ttl_sold":15,"cumulative_ttl_returned":5}},{"200":{"cumulative_ttl_sold":15,"cumulative_ttl_returned":5}}]
6 2022-10-02 US CA S [{"100":{"cumulative_ttl_sold":21,"cumulative_ttl_returned":2}},{"200":{"cumulative_ttl_sold":21,"cumulative_ttl_returned":2}}] [{"100":{"cumulative_ttl_sold":10,"cumulative_ttl_returned":1}},{"200":{"cumulative_ttl_sold":10,"cumulative_ttl_returned":1}}]

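Expected output for row 4 (store 100 comes from latest_payload; store 200 is missing there, so it is carried over from prev_payload):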

{'100': {'cumulative_ttl_sold': 37, 'cumulative_ttl_returned': 17}, '200': {'cumulative_ttl_sold': 20, 'cumulative_ttl_returned': 10}}

Attempted solution (this gives me the wrong values for each row):

# Spark UDF that merges the latest and previous payload arrays into one map.
@F.udf(
    MapType(
        # NOTE(review): store_id is inferred as a Spark long (bigint) from the
        # Python ints in test_table; IntegerType keys only hold 32-bit values.
        IntegerType(),
        # Field names must match the struct built in the aggregation:
        # (cumulative_ttl_sold, cumulative_ttl_returned). A duplicated
        # "cumulative_ttl_sold" mislabels the returned counts.
        StructType([
            StructField("cumulative_ttl_sold", LongType(), False),
            StructField("cumulative_ttl_returned", LongType(), False),
        ])
    )
)
def merge_payloads(lastest_payload, prev_payload):
    """Merge two arrays of single-entry store maps into one map.

    Each argument is a list of maps ``{store_id: metrics}`` (as produced by
    ``collect_list(create_map(...))``) or ``None``. Stores present in
    ``lastest_payload`` take their metrics from it; stores that appear only in
    ``prev_payload`` are carried over unchanged.

    Args:
        lastest_payload: Payload for the current date. The parameter name is
            kept as-is (typo included) for compatibility. ``None`` is treated
            as an empty list.
        prev_payload: Payload for the previous date, or ``None`` when there is
            no earlier date (``lag`` yields null for the first date).

    Returns:
        A dict mapping store_id to its metrics struct.
    """
    payload = {}
    for entry in lastest_payload or ():
        if entry:
            # Later entries overwrite earlier ones for a repeated store_id.
            payload.update(entry)
    for entry in prev_payload or ():
        if not entry:
            continue
        for store_id, metrics in entry.items():
            # Only fill in stores the latest payload did not report; skip
            # (do not stop at) keys that are already present.
            payload.setdefault(store_id, metrics)
    return payload

Answer:

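Wrap this in `@F.udf` with the appropriate return type and it will do what you're looking for. Because `dict` keeps the last value seen for a key, listing `prev_payload` first and `latest_payload` second makes the latest values win for stores present in both: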

def merge_payloads(latest_payload, prev_payload):
    """Merge two lists of single-entry maps into one dict, letting latest win.

    ``prev_payload`` is applied first and ``latest_payload`` second, so a key
    present in both keeps the value from ``latest_payload``. Keys that appear
    only in ``prev_payload`` are carried over.

    Either argument may be ``None`` and is then treated as an empty list.
    ``prev_payload`` is null for the first date produced by ``lag``. ``None``
    entries inside the lists are skipped.
    """
    merged = {}
    for entry in [*(prev_payload or ()), *(latest_payload or ())]:
        if entry:
            merged.update(entry)
    return merged

# Worked example for row 4: store "100" appears in both snapshots, so its
# latest totals win; store "200" exists only in the previous snapshot and is
# carried over unchanged.
latest = [
    {'100': {'cumulative_ttl_sold': 37, 'cumulative_ttl_returned': 17}},
]
prev = [
    {'100': {'cumulative_ttl_sold': 20, 'cumulative_ttl_returned': 10}},
    {'200': {'cumulative_ttl_sold': 20, 'cumulative_ttl_returned': 10}},
]

print(merge_payloads(latest, prev))

# Output:
# {'100': {'cumulative_ttl_sold': 37, 'cumulative_ttl_returned': 17},
#  '200': {'cumulative_ttl_sold': 20, 'cumulative_ttl_returned': 10}}
  • Related