I have a DataFrame:
df = spark.createDataFrame(
[(2022, 1, 3, '01', ['apple', 'banana', 'orange'],
[['apple', 'edible', 'fruit', 'green'], ['largest', 'herbaceous', 'flowering', 'plant', 'Vitamin B', 'fruit'],
['source', 'Vitamin C', 'fruit']], [['fruit', 2], ['Vitamin', 2]]),
(2022, 1, 3, '02', ['apple', 'banana', 'avocado'],
[['apple', 'edible', 'fruit', 'green'], ['largest', 'herbaceous', 'flowering', 'plant', 'Vitamin B', 'fruit'],
['medium', 'dark', 'green', 'fruit']], [['fruit', 3], ['green', 2]]),
(2022, 2, 4, '03', ['pomelo', 'fig'],
[['citrus', 'fruit', 'sweet'], ['soft', 'sweet']], [['sweet', 2]]), ],
['year', 'month', 'day', 'id', "list_of_fruits",
'collected_tokens', 'most_common_word']
)
---- ----- --- --- ------------------------ ------------------------------------------------------------------------------------------------------------------------ --------------------------
|year|month|day|id |list_of_fruits |collected_tokens |most_common_word |
---- ----- --- --- ------------------------ ------------------------------------------------------------------------------------------------------------------------ --------------------------
|2022|1 |3 |01 |[apple, banana, orange] |[[apple, edible, fruit, green], [largest, herbaceous, flowering, plant, Vitamin B, fruit], [source, Vitamin C, fruit]] |[[fruit, 2], [Vitamin, 2]]|
|2022|1 |3 |02 |[apple, banana, avocado]|[[apple, edible, fruit, green], [largest, herbaceous, flowering, plant, Vitamin B, fruit], [medium, dark, green, fruit]]|[[fruit, 3], [green, 2]] |
|2022|2 |4 |03 |[pomelo, fig] |[[citrus, fruit, sweet], [soft, sweet]] |[[sweet, 2]] |
---- ----- --- --- ------------------------ ------------------------------------------------------------------------------------------------------------------------ --------------------------
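I want to group by year, month, and day and intersect the values of the last three columns across the rows of each group. Those columns hold a list, a list of lists, and a list of [key, count] pairs; for the last one, the intersection should keep the minimum count for each key. In the end, I would like this result: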
---- ----- --- --- --------------------------- ------------------------------------------------------------------------------------------ -----------------------------
|year|month|day|id |intersection_list_of_fruits|intersection_collected_tokens |intersection_most_common_word|
---- ----- --- --- --------------------------- ------------------------------------------------------------------------------------------ -----------------------------
|2022|1 |3 |01 |[apple, banana] |[[apple, edible, fruit, green], [largest, herbaceous, flowering, plant, Vitamin B, fruit]]|[[fruit, 2]] |
|2022|1 |3 |02 |[apple, banana] |[[apple, edible, fruit, green], [largest, herbaceous, flowering, plant, Vitamin B, fruit]]|[[fruit, 2]] |
|2022|2 |4 |03 |[pomelo, fig] |[[citrus, fruit, sweet], [soft, sweet]] |[[sweet, 2]] |
---- ----- --- --- --------------------------- ------------------------------------------------------------------------------------------ -----------------------------
So the column intersection_list_of_fruits is missing [orange] and [avocado], the column intersection_collected_tokens is missing [source, Vitamin C, fruit] and [medium, dark, green, fruit], and the column intersection_most_common_word is missing [Vitamin, 2] and [green, 2].
I know about array_intersect, but I need the intersection across rows, so it has to work as an aggregation after the groupby: ids with the same date should be grouped and their values intersected. (I think this can be done with Spark's applyInPandas function.)
CodePudding user response:
You can use aggregate and array_intersect, along with collect_set, to compute the intersection on list_of_fruits and collected_tokens, which gives intersection_list_of_fruits and intersection_collected_tokens.
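To picture what folding array_intersect over the collected arrays does, here is a minimal plain-Python stand-in. It is not Spark code: the helper names array_intersect and intersect_all are illustrative only, and the sketch assumes the intersection keeps the order of the first list and drops duplicates.

```python
from functools import reduce


def array_intersect(a, b):
    """Elements of `a` that also occur in `b`, in `a`'s order, without duplicates."""
    out = []
    for item in a:
        if item in b and item not in out:
            out.append(item)
    return out


def intersect_all(lists):
    """Fold array_intersect over every list of a group, seeded with the first list."""
    return reduce(array_intersect, lists, lists[0])


# The two rows of the (2022, 1, 3) group from the question.
group_fruits = [
    ["apple", "banana", "orange"],   # id 01
    ["apple", "banana", "avocado"],  # id 02
]
group_tokens = [
    [["apple", "edible", "fruit", "green"],
     ["largest", "herbaceous", "flowering", "plant", "Vitamin B", "fruit"],
     ["source", "Vitamin C", "fruit"]],
    [["apple", "edible", "fruit", "green"],
     ["largest", "herbaceous", "flowering", "plant", "Vitamin B", "fruit"],
     ["medium", "dark", "green", "fruit"]],
]

print(intersect_all(group_fruits))
print(intersect_all(group_tokens))
```

A group with a single row folds to that row's own list, which is why id 03 keeps its values unchanged.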
However, intersection_most_common_word also needs to account for the word counts. To do this:
- Find the intersection of the words, ignoring the counts
- Iterate over the intersected words and the collected arrays of most_common_word, and find the minimum count for each word
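The two steps above can be sketched in plain Python first, which may make the Spark version easier to follow. This is only an illustration of the logic: the function name and its input shape (one list of [word, count] pairs per row of a group) are assumptions, and the counts are kept as Python integers.

```python
def intersect_min(rows):
    """rows: one list of [word, count] pairs per row of the group."""
    # Turn each row's pairs into a word -> count mapping.
    maps = [dict(pairs) for pairs in rows]
    # Step 1: words present in every row, in the order of the first row.
    common = [word for word in maps[0] if all(word in m for m in maps)]
    # Step 2: for each common word, keep the smallest count seen in any row.
    return [[word, min(m[word] for m in maps)] for word in common]


# The (2022, 1, 3) group from the question: ids 01 and 02.
group = [
    [["fruit", 2], ["Vitamin", 2]],
    [["fruit", 3], ["green", 2]],
]
print(intersect_min(group))
```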
from pyspark.sql import functions as F
from pyspark.sql import Window as W
from pyspark.sql import Column
df = spark.createDataFrame(
[(2022, 1, 3, '01', ['apple', 'banana', 'orange'],
[['apple', 'edible', 'fruit', 'green'], ['largest', 'herbaceous', 'flowering', 'plant', 'Vitamin B', 'fruit'],
['source', 'Vitamin C', 'fruit']], [['fruit', 2], ['Vitamin', 2]]),
(2022, 1, 3, '02', ['apple', 'banana', 'avocado'],
[['apple', 'edible', 'fruit', 'green'], ['largest', 'herbaceous', 'flowering', 'plant', 'Vitamin B', 'fruit'],
['medium', 'dark', 'green', 'fruit']], [['fruit', 3], ['green', 2]]),
(2022, 2, 4, '03', ['pomelo', 'fig'],
[['citrus', 'fruit', 'sweet'], ['soft', 'sweet']], [['sweet', 2]]), ],
['year', 'month', 'day', 'id', "list_of_fruits",
'collected_tokens', 'most_common_word']
)
def intersection_expr(col_name: str, window_spec: W) -> Column:
    # Collect the arrays of every row in the window, then fold array_intersect over them
    lists = F.collect_set(col_name).over(window_spec)
    return F.aggregate(lists, lists[0], lambda acc, x: F.array_intersect(acc, x))

def intersect_min(col_name: str, window_spec: W) -> Column:
    # Convert each array of [word, count] pairs into a map and collect the maps into a list
    k = F.transform(F.col(col_name), lambda x: x[0])
    v = F.transform(F.col(col_name), lambda x: x[1])
    map_count = F.map_from_arrays(k, v)
    map_counts = F.collect_list(map_count).over(window_spec)
    # Find the keys present in every map
    keys = F.transform(map_counts, lambda x: F.map_keys(x))
    intersected = F.aggregate(keys, keys[0], lambda acc, x: F.array_intersect(acc, x))
    # For each intersected key, take the minimum count across all maps
    res = F.transform(intersected, lambda key: F.array(key, F.array_min(F.transform(map_counts, lambda m: m.getField(key)))))
    return res

window_spec = W.partitionBy("year", "month", "day").orderBy("id").rowsBetween(W.unboundedPreceding, W.unboundedFollowing)
(df.select("year", "month", "day", "id",
intersection_expr("list_of_fruits", window_spec).alias("intersection_list_of_fruits"),
intersection_expr("collected_tokens", window_spec).alias("intersection_collected_tokens"),
intersect_min("most_common_word", window_spec).alias("intersection_most_common_word"))
.show(truncate=False))
"""
---- ----- --- --- --------------------------- ------------------------------------------------------------------------------------------ -----------------------------
|year|month|day|id |intersection_list_of_fruits|intersection_collected_tokens |intersection_most_common_word|
---- ----- --- --- --------------------------- ------------------------------------------------------------------------------------------ -----------------------------
|2022|1 |3 |01 |[apple, banana] |[[apple, edible, fruit, green], [largest, herbaceous, flowering, plant, Vitamin B, fruit]]|[[fruit, 2]] |
|2022|1 |3 |02 |[apple, banana] |[[apple, edible, fruit, green], [largest, herbaceous, flowering, plant, Vitamin B, fruit]]|[[fruit, 2]] |
|2022|2 |4 |03 |[pomelo, fig] |[[citrus, fruit, sweet], [soft, sweet]] |[[sweet, 2]] |
---- ----- --- --- --------------------------- ------------------------------------------------------------------------------------------ -----------------------------
"""