I've got a dataframe of roles and the ids of people who play those roles. In the table below, the roles are a, b, c, d and the people are a3, 36, 79, 38.
What I want is a map of people to an array of their roles, as shown to the right of the table.
+---+----+----+---+---+--------+
|rec|   a|   b|  c|  d|     ppl|  pplmap
+---+----+----+---+---+--------+  -------------------------------------
|  D|  a3|  36| 36| 36|[a3, 36]|  [ a3 -> ['a'], 36 -> ['b','c','d'] ]
|  E|  a3|  79| 79| a3|[a3, 79]|  [ a3 -> ['a','d'], 79 -> ['b','c'] ]
|  F|null|null| 38| 38|    [38]|  [ 38 -> ['c','d'] ]
+---+----+----+---+---+--------+
And, actually, what I really want is a nicely readable report, like:
D
a3 roles: a
36 roles: b, c, d
E
a3 roles: a, d
79 roles: b, c
F
38 roles: c, d
I'm using PySpark 3.
Any suggestions? Thank you!!
CodePudding user response:
You can first unpivot the dataframe, then use two groupBy aggregations to construct the map column you want.
Input dataframe:
data = [
    ("D", "a3", "36", "36", "36", ["a3", "36"]),
    ("E", "a3", "79", "79", "a3", ["a3", "79"]),
    ("F", None, None, "38", "38", ["38"]),
]
df = spark.createDataFrame(data, ["id", "a", "b", "c", "d", "ppl"])
Use the stack function to unpivot, then map_from_entries after grouping:
import pyspark.sql.functions as F

df1 = df.selectExpr(
    "id",
    "stack(4, 'a', a, 'b', b, 'c', c, 'd', d) as (role, person)"
).filter(
    "person is not null"
).groupBy("id", "person").agg(
    F.collect_list("role").alias("roles")
).groupBy("id").agg(
    F.map_from_entries(
        F.collect_list(F.struct(F.col("person"), F.col("roles")))
    ).alias("pplmap")
)
df1.show(truncate=False)
# +---+----------------------------+
# |id |pplmap                      |
# +---+----------------------------+
# |F  |{38 -> [c, d]}              |
# |E  |{79 -> [b, c], a3 -> [a, d]}|
# |D  |{a3 -> [a], 36 -> [b, c, d]}|
# +---+----------------------------+
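If you also want the readable report from the question, one way is to collect the result and iterate over it. A minimal sketch (note that the order of entries inside pplmap is not guaranteed):

for row in df1.orderBy("id").collect():
    # After collect(), a MapType column becomes a plain Python dict.
    print(row.id)
    for person, roles in row.pplmap.items():
        print(f"  {person} roles: {', '.join(roles)}")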
If you want to dynamically generate the stack expression (in case you have many role columns), you can see my other answer here.
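As a rough sketch of that idea, assuming every column other than id and ppl holds a role:

# Build the stack() expression from the column list instead of
# hard-coding the four role columns.
role_cols = [c for c in df.columns if c not in ("id", "ppl")]
stack_expr = "stack({}, {}) as (role, person)".format(
    len(role_cols),
    ", ".join(f"'{c}', `{c}`" for c in role_cols)
)
df1 = df.selectExpr("id", stack_expr)  # then filter and group as above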
CodePudding user response:
Set-up:
import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql.functions import *

df = pd.DataFrame({
    'rec': list('DEF'),
    'a': ['a3', 'a3', None],
    'b': [36, 79, None],
    'c': [36, 79, 38],
    'd': [36, 55, 38]
})

spark = SparkSession.builder \
    .master("local[1]") \
    .appName("SparkByExamples.com") \
    .getOrCreate()

df = spark.createDataFrame(df)
Then melt the DataFrame accordingly, group by values and aggregate by keys:
cols_to_melt = list('abcd')

res = df.withColumn(
        "tmp",
        explode(array(
            [struct(lit(c).alias('key'), col(c).alias('val'))
             for c in cols_to_melt]))) \
    .select('rec', col('tmp.key'), col('tmp.val')) \
    .dropna() \
    .groupby(['rec', 'val']) \
    .agg(collect_list('key').alias('keys')) \
    .groupby('rec') \
    .agg(map_from_entries(collect_list(struct("val", "keys"))).alias('maps'))
res.show(truncate=False)
Output:
+---+----------------------------------------------+
|rec|maps                                          |
+---+----------------------------------------------+
|F  |{38 -> [c, d], NaN -> [b]}                    |
|E  |{79 -> [c], 79.0 -> [b], a3 -> [a], 55 -> [d]}|
|D  |{36.0 -> [b], a3 -> [a], 36 -> [c, d]}        |
+---+----------------------------------------------+
To get your report you just need to iterate through the collected data:
for row in res.collect():
    print(row.rec)
    print('\n'.join(f" {k} roles: {', '.join(v)}" for k, v in row.maps.items()))
Then your final report should look like:
F
 38 roles: c, d
 NaN roles: b
E
 55 roles: d
 79 roles: c
 a3 roles: a
 79.0 roles: b
D
 36.0 roles: b
 a3 roles: a
 36 roles: c, d
One issue that I did not deal with here is that one of your columns contains both numeric and string values, which is not possible in Spark.
If you are converting a pandas DataFrame to a Spark DataFrame (like I do in my example), you should pass an explicit schema. If you are reading from CSV files you might not have to, since the type will be automatically inferred as String. However, in that case, in order to group columns where some have values like 38 and others "38", you should make sure all relevant numeric columns are also converted to String.
So, in any case, it is better to use a schema to ensure you get exactly the types that you need in your DataFrame.
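A minimal sketch of that, assuming you keep all values as strings up front (the data here is the question's original table, where column d mixes a3 with numbers):

from pyspark.sql.types import StructType, StructField, StringType

# Force every column to StringType so values like 38 and "38"
# land in one consistent type.
schema = StructType([
    StructField(name, StringType(), True)
    for name in ["rec", "a", "b", "c", "d"]
])
data = [
    ("D", "a3", "36", "36", "36"),
    ("E", "a3", "79", "79", "a3"),
    ("F", None, None, "38", "38"),
]
df = spark.createDataFrame(data, schema)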