How to intersect rows containing an array for a dataframe in pyspark

Time:03-17

I have a dataframe

df = spark.createDataFrame(
    [(2022, 1, 3, '01', ['apple', 'banana', 'orange'],
      [['apple', 'edible', 'fruit', 'green'], ['largest', 'herbaceous', 'flowering', 'plant', 'Vitamin B', 'fruit'],
       ['source', 'Vitamin C', 'fruit']], [['fruit', 2], ['Vitamin', 2]]),
     (2022, 1, 3, '02', ['apple', 'banana', 'avocado'],
     [['apple', 'edible', 'fruit', 'green'], ['largest', 'herbaceous', 'flowering', 'plant', 'Vitamin B', 'fruit'],
      ['medium', 'dark', 'green', 'fruit']], [['fruit', 3], ['green', 2]]),
     (2022, 2, 4, '03', ['pomelo', 'fig'],
     [['citrus', 'fruit', 'sweet'], ['soft', 'sweet']], [['sweet', 2]]), ],
    ['year', 'month', 'day', 'id', "list_of_fruits",
        'collected_tokens', 'most_common_word']
)

+----+-----+---+---+------------------------+------------------------------------------------------------------------------------------------------------------------+--------------------------+
|year|month|day|id |list_of_fruits          |collected_tokens                                                                                                        |most_common_word          |
+----+-----+---+---+------------------------+------------------------------------------------------------------------------------------------------------------------+--------------------------+
|2022|1    |3  |01 |[apple, banana, orange] |[[apple, edible, fruit, green], [largest, herbaceous, flowering, plant, Vitamin B, fruit], [source, Vitamin C, fruit]]  |[[fruit, 2], [Vitamin, 2]]|
|2022|1    |3  |02 |[apple, banana, avocado]|[[apple, edible, fruit, green], [largest, herbaceous, flowering, plant, Vitamin B, fruit], [medium, dark, green, fruit]]|[[fruit, 3], [green, 2]]  |
|2022|2    |4  |03 |[pomelo, fig]           |[[citrus, fruit, sweet], [soft, sweet]]                                                                                 |[[sweet, 2]]              |
+----+-----+---+---+------------------------+------------------------------------------------------------------------------------------------------------------------+--------------------------+

I want to group by year, month, and day and, within each group, intersect the rows of the last three columns: a list, a list of lists, and a list of [word, count] pairs for which the minimum count should be kept. In the end, I would like this result:

+----+-----+---+---+---------------------------+------------------------------------------------------------------------------------------+-----------------------------+
|year|month|day|id |intersection_list_of_fruits|intersection_collected_tokens                                                             |intersection_most_common_word|
+----+-----+---+---+---------------------------+------------------------------------------------------------------------------------------+-----------------------------+
|2022|1    |3  |01 |[apple, banana]            |[[apple, edible, fruit, green], [largest, herbaceous, flowering, plant, Vitamin B, fruit]]|[[fruit, 2]]                 |
|2022|1    |3  |02 |[apple, banana]            |[[apple, edible, fruit, green], [largest, herbaceous, flowering, plant, Vitamin B, fruit]]|[[fruit, 2]]                 |
|2022|2    |4  |03 |[pomelo, fig]              |[[citrus, fruit, sweet], [soft, sweet]]                                                   |[[sweet, 2]]                 |
+----+-----+---+---+---------------------------+------------------------------------------------------------------------------------------+-----------------------------+

So intersection_list_of_fruits drops [orange] and [avocado]; intersection_collected_tokens drops [source, Vitamin C, fruit] and [medium, dark, green, fruit]; and intersection_most_common_word drops [Vitamin, 2] and [green, 2].

I know about array_intersect, but I need the intersection across rows rather than across columns, and because of the groupby I also need an aggregation function that groups the ids sharing the same date and intersects them. (I think this can be done with Spark's applyInPandas function.)

CodePudding user response:

You can use aggregate and array_intersect, together with collect_set over a window, to intersect list_of_fruits and collected_tokens across the rows of each group and obtain intersection_list_of_fruits and intersection_collected_tokens.
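To make the fold concrete, here is a minimal plain-Python sketch of what aggregate combined with array_intersect computes for one group. The helper names `array_intersect` and `intersect_all` are hypothetical stand-ins written for this illustration, not PySpark calls, and the data is the 2022-01-03 group from the question.

```python
from functools import reduce

def array_intersect(a, b):
    # Hypothetical stand-in for Spark's array_intersect:
    # keep the elements of `a` that also occur in `b`, in order, without duplicates.
    out = []
    for item in a:
        if item in b and item not in out:
            out.append(item)
    return out

def intersect_all(lists):
    # Fold over every collected list, starting from the first one,
    # mirroring aggregate(lists, lists[0], array_intersect).
    return reduce(array_intersect, lists, lists[0])

# Values of list_of_fruits for ids 01 and 02
fruits = intersect_all([
    ['apple', 'banana', 'orange'],
    ['apple', 'banana', 'avocado'],
])

# Values of collected_tokens for ids 01 and 02 (elements are lists themselves)
tokens = intersect_all([
    [['apple', 'edible', 'fruit', 'green'],
     ['largest', 'herbaceous', 'flowering', 'plant', 'Vitamin B', 'fruit'],
     ['source', 'Vitamin C', 'fruit']],
    [['apple', 'edible', 'fruit', 'green'],
     ['largest', 'herbaceous', 'flowering', 'plant', 'Vitamin B', 'fruit'],
     ['medium', 'dark', 'green', 'fruit']],
])

print(fruits)
print(tokens)
```

Starting the fold from the first list is harmless because intersecting a list with itself changes nothing beyond removing duplicates.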

However, intersection_most_common_word also needs to account for the word counts. To do this:

  1. Find the intersections of words excluding counts
  2. Iterate over the intersected words and the collected arrays in most_common_word to find the minimum count

from pyspark.sql import functions as F
from pyspark.sql import Window as W
from pyspark.sql import Column

df = spark.createDataFrame(
    [(2022, 1, 3, '01', ['apple', 'banana', 'orange'],
      [['apple', 'edible', 'fruit', 'green'], ['largest', 'herbaceous', 'flowering', 'plant', 'Vitamin B', 'fruit'],
       ['source', 'Vitamin C', 'fruit']], [['fruit', 2], ['Vitamin', 2]]),
     (2022, 1, 3, '02', ['apple', 'banana', 'avocado'],
     [['apple', 'edible', 'fruit', 'green'], ['largest', 'herbaceous', 'flowering', 'plant', 'Vitamin B', 'fruit'],
      ['medium', 'dark', 'green', 'fruit']], [['fruit', 3], ['green', 2]]),
     (2022, 2, 4, '03', ['pomelo', 'fig'],
     [['citrus', 'fruit', 'sweet'], ['soft', 'sweet']], [['sweet', 2]]), ],
    ['year', 'month', 'day', 'id', "list_of_fruits",
        'collected_tokens', 'most_common_word']
)

def intersection_expr(col_name: str, window_spec: W) -> Column:
    lists = F.collect_set(col_name).over(window_spec)
    return F.aggregate(lists, lists[0], lambda acc,x: F.array_intersect(acc, x))



def intersect_min(col_name: str, window_spec: W) -> Column:
    # Convert array into map of word and count and collect into set
    k = F.transform(F.col(col_name), lambda x: x[0])
    v = F.transform(F.col(col_name), lambda x: x[1])
    map_count = F.map_from_arrays(k, v)
    map_counts = F.collect_list(map_count).over(window_spec)
    
    # Find keys present in all lists
    keys = F.transform(map_counts, lambda x: F.map_keys(x))
    intersected = F.aggregate(keys, keys[0], lambda acc,x: F.array_intersect(acc, x))
    
    # For each intersected key, find the minimum value across all maps
    # (m[key] replaces getField(key), which is deprecated for Column arguments since Spark 3.0)
    res = F.transform(intersected, lambda key: F.array(key, F.array_min(F.transform(map_counts, lambda m: m[key]))))
    
    return res

window_spec = W.partitionBy("year", "month", "day").orderBy("id").rowsBetween(W.unboundedPreceding, W.unboundedFollowing)

(df.select("year", "month", "day", "id",
        intersection_expr("list_of_fruits", window_spec).alias("intersection_list_of_fruits"), 
        intersection_expr("collected_tokens", window_spec).alias("intersection_collected_tokens"),
        intersect_min("most_common_word", window_spec).alias("intersection_most_common_word"))
    .show(truncate=False))


"""
+----+-----+---+---+---------------------------+------------------------------------------------------------------------------------------+-----------------------------+
|year|month|day|id |intersection_list_of_fruits|intersection_collected_tokens                                                             |intersection_most_common_word|
+----+-----+---+---+---------------------------+------------------------------------------------------------------------------------------+-----------------------------+
|2022|1    |3  |01 |[apple, banana]            |[[apple, edible, fruit, green], [largest, herbaceous, flowering, plant, Vitamin B, fruit]]|[[fruit, 2]]                 |
|2022|1    |3  |02 |[apple, banana]            |[[apple, edible, fruit, green], [largest, herbaceous, flowering, plant, Vitamin B, fruit]]|[[fruit, 2]]                 |
|2022|2    |4  |03 |[pomelo, fig]              |[[citrus, fruit, sweet], [soft, sweet]]                                                   |[[sweet, 2]]                 |
+----+-----+---+---+---------------------------+------------------------------------------------------------------------------------------+-----------------------------+
"""