Find array intersection for each row in Pyspark

Time:11-30

I have a DataFrame:

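import pandas as pd
from pyspark.sql import SparkSession, functions as F
spark = SparkSession.builder.getOrCreate()
df_example = pd.DataFrame({'user1': ['u1', 'u1', 'u1', 'u5', 'u5', 'u5', 'u7', 'u7', 'u6'],
                           'user2': ['u2', 'u3', 'u4', 'u2', 'u4', 'u6', 'u8', 'u3', 'u6']})
sdf = spark.createDataFrame(df_example)
# Group the user2 values of each user1 into one array
userreposts_gr = sdf.groupby('user1').agg(F.collect_list('user2').alias('all_user2'))
userreposts_gr.show()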
+-----+------------+
|user1|   all_user2|
+-----+------------+
|   u1|[u4, u2, u3]|
|   u7|    [u8, u3]|
|   u5|[u4, u2, u6]|
|   u6|        [u6]|
+-----+------------+
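For readers less familiar with `collect_list`, here is a minimal plain-Python sketch (not Spark code) of what the grouping produces for the same sample pairs. One thing worth knowing: Spark does not guarantee the order of the elements `collect_list` gathers, because row order after a shuffle is non-deterministic, so `[u4, u2, u3]` may come back in a different order on another run.

```python
from collections import defaultdict

# Same (user1, user2) pairs as df_example above
pairs = [('u1', 'u2'), ('u1', 'u3'), ('u1', 'u4'),
         ('u5', 'u2'), ('u5', 'u4'), ('u5', 'u6'),
         ('u7', 'u8'), ('u7', 'u3'), ('u6', 'u6')]

# Plain-Python analogue of groupby('user1').agg(collect_list('user2'))
all_user2 = defaultdict(list)
for user1, user2 in pairs:
    all_user2[user1].append(user2)

for user1, values in all_user2.items():
    print(user1, values)
```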

For each user1, I want to compare its all_user2 array with the arrays of all the other users, and add a new column that holds the largest intersection size together with the user it was found with:

+-----+------------+------------------------------+
|user1|all_user2   |new_col                       |
+-----+------------+------------------------------+
|u1   |[u2, u3, u4]|{max_count -> 2, user -> 'u5'}|
|u5   |[u2, u4, u6]|{max_count -> 2, user -> 'u1'}|
|u7   |[u8, u3]    |{max_count -> 1, user -> 'u1'}|
|u6   |[u6]        |{max_count -> 1, user -> 'u5'}|
+-----+------------+------------------------------+
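To make the requirement concrete, here is the arithmetic behind the first row. It writes $A_u$ for the set of all_user2 values of user $u$ (notation introduced here for illustration, not taken from the question). The user is not compared with itself, and the best match is the user with the largest overlap:

```latex
\begin{aligned}
|A_{u1} \cap A_{u5}| &= |\{u2, u3, u4\} \cap \{u2, u4, u6\}| = |\{u2, u4\}| = 2 \\
|A_{u1} \cap A_{u7}| &= |\{u2, u3, u4\} \cap \{u8, u3\}| = |\{u3\}| = 1 \\
|A_{u1} \cap A_{u6}| &= |\{u2, u3, u4\} \cap \{u6\}| = |\varnothing| = 0 \\
\text{new\_col}(u1) &= \{\text{max\_count} \to 2,\ \text{user} \to u5\}
\end{aligned}
```

The other rows follow the same pattern; for u6, only u5 shares an element (u6), which gives max_count 1.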

Answer:

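Step 1: Cross join the grouped DataFrame with itself to get every pair of user1 values.

Step 2: For each pair, compute the size of the intersection of the two all_user2 arrays.

Step 3: Drop the pairs where a user is matched with itself, rank the remaining pairs by intersection size within each user1, and keep the top-ranked pair.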
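The three steps can be checked without a cluster. Below is a plain-Python sketch of the same logic (it is not Spark code): `itertools.product` stands in for the cross join, a set intersection stands in for `array_intersect` (which also drops duplicates), and keeping every maximal score mirrors filtering on `rank() == 1`.

```python
from itertools import product

# Grouped data as shown above (element order inside each list does not matter)
all_user2 = {'u1': ['u4', 'u2', 'u3'], 'u7': ['u8', 'u3'],
             'u5': ['u4', 'u2', 'u6'], 'u6': ['u6']}

# Step 1: cross join, i.e. every ordered pair of users, minus self-pairs
pairs = [(a, b) for a, b in product(all_user2, repeat=2) if a != b]

# Step 2: size of the intersection of the two arrays (set semantics)
scores = {(a, b): len(set(all_user2[a]) & set(all_user2[b])) for a, b in pairs}

# Step 3: per user1, keep every partner with the top score (what rank() == 1 keeps)
new_col = {}
for user1 in all_user2:
    candidates = {b: s for (a, b), s in scores.items() if a == user1}
    best = max(candidates.values())
    new_col[user1] = [(b, s) for b, s in candidates.items() if s == best]

for user1, matches in new_col.items():
    print(user1, matches)
```

Because the cross join pairs every user with every other user, the number of pairs grows quadratically with the number of distinct user1 values, in Spark as well as in this sketch.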

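from pyspark.sql import Window
from pyspark.sql import functions as func

# Steps 1 and 2: pair every user1 with every user1 (itself included)
# and count how many user2 values the two arrays share
output = userreposts_gr\
    .selectExpr(
        'user1', 'all_user2 AS arr1'
    ).crossJoin(
        userreposts_gr.selectExpr('user1 AS user2', 'all_user2 AS arr2')
    ).withColumn(
        'intersection', func.size(func.array_intersect('arr1', 'arr2'))
    )

# Step 3a: drop self-pairs, then rank each user1's candidates by overlap
output = output\
    .filter(
        func.col('user1') != func.col('user2')
    ).withColumn(
        'ranking', func.rank().over(Window.partitionBy('user1').orderBy(func.desc('intersection')))
    )

# Step 3b: keep the best match; map values must share one type,
# so the integer count is cast to string explicitly
output = output\
    .filter(
        func.col('ranking') == 1
    ).withColumn(
        'new_col', func.create_map(func.lit('max_count'), func.col('intersection').cast('string'),
                                   func.lit('user'), func.col('user2'))
    )

output = output\
    .selectExpr(
        'user1', 'arr1 AS all_user2', 'new_col'
    )

output.show(100, False)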
+-----+------------+----------------------------+
|user1|all_user2   |new_col                     |
+-----+------------+----------------------------+
|u1   |[u2, u3, u4]|{max_count -> 2, user -> u5}|
|u5   |[u2, u4, u6]|{max_count -> 2, user -> u1}|
|u6   |[u6]        |{max_count -> 1, user -> u5}|
|u7   |[u8, u3]    |{max_count -> 1, user -> u1}|
+-----+------------+----------------------------+
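One caveat: `rank()` gives tied rows the same rank, so if two candidates share the top intersection size, `ranking == 1` keeps both and that user1 appears twice in the result. No such tie occurs in the sample data. If exactly one row per user is wanted, `func.row_number()` over a window ordered by `func.desc('intersection'), 'user2'` keeps a single row with a deterministic tie-breaker. The plain-Python sketch below contrasts the two behaviors on hypothetical scores (the names `ub`, `uc`, `ud` and their scores are made up for illustration).

```python
# Hypothetical intersection sizes for one user1; 'ub' and 'uc' tie
scores = {'ub': 2, 'uc': 2, 'ud': 1}

# Order candidates by score descending, then by user name as a tie-breaker
ordered = sorted(scores.items(), key=lambda kv: (-kv[1], kv[0]))

# rank(): tied rows share a rank, so rank == 1 keeps every top-scoring candidate
best = ordered[0][1]
rank_1 = [user for user, score in ordered if score == best]

# row_number(): every row gets a distinct number, so row_number == 1 keeps one row
row_number_1 = [ordered[0][0]]

print(rank_1)        # ['ub', 'uc']
print(row_number_1)  # ['ub']
```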