How to create a dictionary from hierarchical relations in a Spark DataFrame

Time:07-25

I have the following Spark DataFrame:

schema = 'EMPLOYEE_NUMBER int, MANAGER_EMPLOYEE_NUMBER int'
employees = spark.createDataFrame(
[[801,None], 
[1016,801], 
[1003,801], 
[1019,801], 
[1010,1003], 
[1004,1003], 
[1001,1003],
[1012,1004], 
[1002,1004], 
[1015,1004], 
[1008,1019], 
[1006,1019], 
[1014,1019],
[1011,1019]], schema=schema)

I want to create a dictionary from the above DataFrame that maps each manager to their direct reports, like {801:[1003,1019,1016], 1019:[1014,1011,1008,1006], 1003:[1010,1001,1004]}. Can I build a dictionary like this from the DataFrame?

CodePudding user response:

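from pyspark.sql.functions import col, collect_set

# Group employees by manager, gather each group's employee numbers, and pull the rows to the driver
datas = employees.groupBy('MANAGER_EMPLOYEE_NUMBER').agg(collect_set(col('EMPLOYEE_NUMBER')).alias('values')).collect()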

ur_dict = {}
for item in datas:
    ur_dict[item['MANAGER_EMPLOYEE_NUMBER']] = item['values']

print(ur_dict)
# {1019: [1014, 1011, 1008, 1006], None: [801], 1003: [1004, 1001, 1010], 801: [1003, 1019, 1016], 1004: [1002, 1015, 1012]}
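The output above contains a None key for the top-level employee, and collect_set does not guarantee the order of the collected values. If neither is wanted, the dictionary can be tidied up afterwards in plain Python. A minimal sketch, assuming the dictionary printed above; the helper name normalize is hypothetical and not part of the original answer:

```python
def normalize(manager_to_reports):
    """Drop the None key (employees without a manager) and sort each list of reports."""
    return {
        manager: sorted(reports)
        for manager, reports in manager_to_reports.items()
        if manager is not None
    }

# The dictionary printed by the answer above, copied here so the sketch runs on its own
ur_dict = {
    1019: [1014, 1011, 1008, 1006],
    None: [801],
    1003: [1004, 1001, 1010],
    801: [1003, 1019, 1016],
    1004: [1002, 1015, 1012],
}

cleaned = normalize(ur_dict)
print(cleaned)
```

Sorting is only one way to get a stable order; it is used here because the question does not require any particular ordering.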

CodePudding user response:

You can use collect_list to group all the employees under each MANAGER_EMPLOYEE_NUMBER, then use collect and asDict together with map to turn the result into a dictionary.

Data Preparation

schema = 'EMPLOYEE_NUMBER int, MANAGER_EMPLOYEE_NUMBER int'

employees = spark.createDataFrame(
[[801,None], 
[1016,801], 
[1003,801], 
[1019,801], 
[1010,1003], 
[1004,1003], 
[1001,1003],
[1012,1004], 
[1002,1004], 
[1015,1004], 
[1008,1019], 
[1006,1019], 
[1014,1019],
[1011,1019]], schema=schema)

Collect List

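from pyspark.sql import functions as F

# Collect each manager's direct reports into a list, then drop the row for employees without a manager
employees_agg = employees.groupBy('MANAGER_EMPLOYEE_NUMBER')\
                         .agg(F.collect_list(F.col('EMPLOYEE_NUMBER')).alias('EMPLOYEES'))\
                         .filter(F.col('MANAGER_EMPLOYEE_NUMBER').isNotNull())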


employees_agg.show()

+-----------------------+--------------------+
|MANAGER_EMPLOYEE_NUMBER|           EMPLOYEES|
+-----------------------+--------------------+
|                   1019|[1008, 1006, 1014...|
|                   1003|  [1010, 1004, 1001]|
|                    801|  [1016, 1003, 1019]|
|                   1004|  [1012, 1002, 1015]|
+-----------------------+--------------------+

Transform

final_dict = {
        row['MANAGER_EMPLOYEE_NUMBER']: row['EMPLOYEES'] 
            for row in  map(lambda row: row.asDict(), employees_agg.collect())
}


from pprint import pprint

pprint(final_dict)

{
 801: [1016, 1003, 1019],
 1003: [1010, 1004, 1001],
 1004: [1012, 1002, 1015],
 1019: [1008, 1006, 1014, 1011]
}
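Since both answers end with collect, the same mapping could also be built entirely on the driver when the data is small enough to collect in the first place. A minimal sketch under that assumption: plain tuples stand in for the Row objects that employees.collect() would return, and the helper name build_manager_map is hypothetical.

```python
from collections import defaultdict

def build_manager_map(pairs):
    """Group (employee, manager) pairs into {manager: [employees]}, skipping rows with no manager."""
    result = defaultdict(list)
    for employee, manager in pairs:
        if manager is not None:
            result[manager].append(employee)
    return dict(result)

# Stand-in for employees.collect(): (EMPLOYEE_NUMBER, MANAGER_EMPLOYEE_NUMBER) pairs
pairs = [
    (801, None),
    (1016, 801), (1003, 801), (1019, 801),
    (1010, 1003), (1004, 1003), (1001, 1003),
    (1012, 1004), (1002, 1004), (1015, 1004),
    (1008, 1019), (1006, 1019), (1014, 1019), (1011, 1019),
]

manager_map = build_manager_map(pairs)
print(manager_map)
```

For larger data, the groupBy with collect_list shown above keeps the grouping inside Spark and only brings one row per manager back to the driver.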
