PySpark split map column into Multiple based on starting value of the Key

Time:07-30

I have a DataFrame that looks like this:

ID  MPCol
1   [a1 -> 1, b1 -> 2, a12 -> 5, b23 -> 2, c12 -> 2]
2   [a2 -> 2, b3 -> 3, a15 -> 1, b45 -> 1, c54 -> 2]
3   [a17 -> 2, b15 -> 1, a88 -> 2, b90 -> 8, c98 -> 5]

I want to get something like this:

ID  MPCol1                MPCol2                MPCol3
1   [a1 -> 1, a12 -> 5]   [b1 -> 2, b23 -> 2]   [c12 -> 2]
2   [a2 -> 2, a15 -> 1]   [b3 -> 3, b45 -> 1]   [c54 -> 2]
3   [a17 -> 2, a88 -> 2]  [b15 -> 1, b90 -> 8]  [c98 -> 5]

I want to split the map by the starting letter of each key: all keys starting with a go into one column, all keys starting with b into another, and likewise for c.

My approach:


df.withColumn("MPCOL")
  .select($"MPCOL", explode($"A1"))
  .groupBy("MPCOL")
  .pivot("key")
  .agg(first("value")).show()

CodePudding user response:

Assuming you need a Scala solution (judging by how your input looks), you can use a udf() to do the grouping:

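// Assumes a SparkSession named `spark` is in scope (as in spark-shell).
import org.apache.spark.sql.functions._
import spark.implicits._

val df = Seq(
  Map("a1" -> 1, "b1" -> 2, "a12" -> 5, "b23" -> 2, "c12" -> 2),
  Map("a2" -> 2, "b3" -> 3, "a15" -> 1, "b45" -> 1, "c54" -> 2),
  Map("a17" -> 2, "b15" -> 1, "a88" -> 2, "b90" -> 8, "c98" -> 5),
).toDF("MPCol")
  .select((monotonically_increasing_id() + lit(1)).as("ID"), $"*")

// Group each map's entries by the first character of the key.
val group = udf((_: Map[String, Int]).groupBy(_._1.substring(0, 1)))
// Collect the distinct first characters that occur anywhere in the column.
val keys = df.select(explode($"MPCol")).select($"key".substr(0, 1)).distinct().map(_.getString(0)).collect
// One output column per first character, in sorted order: MPCol1, MPCol2, ...
val cols = keys.sorted.zipWithIndex.map(k => $"group".getItem(k._1).as(s"MPCol${k._2 + 1}")).prepended($"ID")

df.show(false)
df.withColumn("group", group($"MPCol")).select(cols: _*).show(false)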

In PySpark, it's best to avoid udf() because Python UDFs can be slow:

from pyspark.sql import functions as F

df = spark.createDataFrame(
    [[{"a1": 1, "b1": 2, "a12": 5, "b23": 2, "c12": 2}],
     [{"a2": 2, "b3": 3, "a15": 1, "b45": 1, "c54": 2}],
     [{"a17": 2, "b15": 1, "a88": 2, "b90": 8, "c98": 5}],
     ], ["MPCol"]
).select((F.monotonically_increasing_id() + F.lit(1)).alias('ID'), "*")

# One row per map entry, plus the first character of its key.
df2 = df.select('ID', F.explode('MPCol'), F.col('key').substr(0, 1).alias('first_char'))

# Map each distinct first character to an output column name (MPCol1, MPCol2, ...).
name_map = df2.select('first_char').distinct().sort('first_char').withColumn(
    'col_name', F.concat(F.lit('MPCol'), F.monotonically_increasing_id() + F.lit(1))
).toPandas()

df.show(truncate=False)
(
    df2.replace(name_map['first_char'].tolist(), name_map['col_name'].tolist())
    .groupby('ID', 'first_char')
    .agg(F.map_from_arrays(F.collect_list('key'), F.collect_list('value')).alias('collected'))
    .groupby('ID')
    .pivot('first_char')
    .agg(F.first('collected'))
    .show(truncate=False)
)
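
To check the output of either solution, it can help to state the intended result in plain Python first. The following is only a sketch of the grouping logic, not a Spark implementation: the names split_by_prefix, rows, prefixes, col_names and expected are made up for this illustration, and it assumes the data is small enough to hold in local memory.

```python
from collections import defaultdict


def split_by_prefix(mapping, prefix_len=1):
    """Group a dict's entries by the first `prefix_len` characters of each key."""
    groups = defaultdict(dict)
    for key, value in mapping.items():
        groups[key[:prefix_len]][key] = value
    return dict(groups)


# Same sample data as in the question (hypothetical local copy).
rows = [
    (1, {"a1": 1, "b1": 2, "a12": 5, "b23": 2, "c12": 2}),
    (2, {"a2": 2, "b3": 3, "a15": 1, "b45": 1, "c54": 2}),
    (3, {"a17": 2, "b15": 1, "a88": 2, "b90": 8, "c98": 5}),
]

# Sorted distinct prefixes decide the column order: a -> MPCol1, b -> MPCol2, ...
prefixes = sorted({key[:1] for _, mapping in rows for key in mapping})
col_names = {p: f"MPCol{i}" for i, p in enumerate(prefixes, start=1)}

expected = []
for row_id, mapping in rows:
    groups = split_by_prefix(mapping)
    out = {"ID": row_id}
    for p in prefixes:
        # None when a row has no key starting with this prefix.
        out[col_names[p]] = groups.get(p)
    expected.append(out)

for row in expected:
    print(row)
```

Each dictionary in expected corresponds to one row of the desired table, so it can be compared against what the Spark job returns for a small sample.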