Simplify multiple complex PySpark aggregations (30 columns) in one go


I have a sample spark df as below:

df = spark.createDataFrame(
    [[1, 'a', 'b', 'c'],
     [1, 'b', 'c', 'b'],
     [1, 'b', 'a', 'b'],
     [2, 'c', 'a', 'a'],
     [3, 'b', 'b', 'a']],
    ['id', 'field1', 'field2', 'field3'])

What I need next is to run multiple aggregations that summarize the counts of the a, b, c values for each field. I have a working but tedious process, shown below:

from pyspark.sql.functions import col, when, sum

agg_table = (
  df
        .groupBy('id') 
        .agg(
          # field1
             sum(when(col('field1') == 'a',1).otherwise(0)).alias('field1_a_count')
             ,sum(when(col('field1') == 'b',1).otherwise(0)).alias('field1_b_count')
             ,sum(when(col('field1') == 'c',1).otherwise(0)).alias('field1_c_count')
          # field2
             ,sum(when(col('field2') == 'a',1).otherwise(0)).alias('field2_a_count')
             ,sum(when(col('field2') == 'b',1).otherwise(0)).alias('field2_b_count')
             ,sum(when(col('field2') == 'c',1).otherwise(0)).alias('field2_c_count')
          # field3
             ,sum(when(col('field3') == 'a',1).otherwise(0)).alias('field3_a_count')
             ,sum(when(col('field3') == 'b',1).otherwise(0)).alias('field3_b_count')
             ,sum(when(col('field3') == 'c',1).otherwise(0)).alias('field3_c_count')
         ))

What I am expecting to get is this:

agg_table = (['id': '1', '2', '3'],
             ['field1_a_count': 1, 0, 0],
             ['field1_b_count': 2, 0, 1],
             ['field1_c_count': 0, 1, 0],
             ['field2_a_count': 1, 1, 0],
             ['field2_b_count': 1, 0, 1],
             ['field2_c_count': 1, 0, 0],
             ['field3_a_count': 0, 1, 1],
             ['field3_b_count': 2, 0, 0],
             ['field3_c_count': 1, 0, 0])

It would be fine if I only had 3 fields, but I have 30 fields with varying/custom names. Maybe somebody can help me with the repetitive task of coding the aggregated sum per field. I tried playing around with a suggestion from: https://danvatterott.com/blog/2018/09/06/python-aggregate-udfs-in-pyspark/

I can make it work if I only pull one column and one value, but otherwise I get varying errors; one of them is:

AnalysisException: cannot resolve '`value`' given input columns: ['field1','field2','field3']

One last thing I tried:

validated_cols = ['field1','field2','field3']

df.select(validated_cols).groupBy('id').agg(collect_list($'field1_a_count', $'field1_b_count', $'field1_c_count', ...
    $'field30_c_count')).show()

Output: SyntaxError: invalid syntax

I tried pivot too, but from my searches so far it seems to only work on one column. I tried this for multiple columns:

df.withColumn("p", concat($"p1", $"p2"))
  .groupBy("a", "b")
  .pivot("p")
  .agg(...)

I still get a syntax error.
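
I suspect the $"col" shorthand is Scala-only, which would explain the syntax error. My rough PySpark translation of the same idea, written against my sample columns just for illustration, would be:

from pyspark.sql import functions as f

# concatenate two columns into one, then pivot on the combined column
(df.withColumn("p", f.concat(f.col("field1"), f.col("field2")))
   .groupBy("id")
   .pivot("p")
   .agg(f.count("*")))

But even then I am still only pivoting on a single (concatenated) column.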

Another link I tried: https://danvatterott.com/blog/2019/02/05/complex-aggregations-in-pyspark/

I also tried the exprs approach:

exprs1 = {x: "sum" for x in df.columns if x != 'id'}
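
From what I understand, that dict form only applies one function to each raw column, roughly like this, so with string values in the fields it cannot give me the per-value counts I am after:

df.groupBy('id').agg(exprs1).show()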

Any suggestions will be appreciated. Thanks!

CodePudding user response:

Let me answer your question in two steps. First, you are wondering whether it is possible to avoid hard coding every one of those aggregations. It is. I would do it like this:

from pyspark.sql import functions as f

# let's assume that this is known, but we could compute it as well
values = ['a', 'b', 'c']
# All the columns except the id
cols = [ c for c in df.columns if c != 'id' ]

def count_values(column, value):
    return f.sum(f.when(f.col(column) == value, 1).otherwise(0))\
            .alias(f"{column}_{value}_count")

# And this gives you the result of your hard coded aggregations:
df\
    .groupBy('id')\
    .agg(*[count_values(c, value) for c in cols for value in values])\
    .show()
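
If you do not want to hard code values either, one way to derive it from the data (assuming the set of distinct values is small enough to collect to the driver) would be:

values = sorted({row[c]
                 for row in df.select(*cols).distinct().collect()
                 for c in cols})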

But that is not quite the shape you expect, right? You are effectively trying to compute a pivot on the id column. To do that, I would not build on the previous result, but reshape the data differently. I would start by replacing every column of the dataframe except id (which gets renamed to x) with an array of strings of the form {column_name}_{value}_count, and then explode that array. From there, we just need a simple pivot on the former id column (now x), grouped by the values contained in the exploded array.

df\
    .select(f.col('id').alias('x'), f.explode(
         f.array(
             [f.concat_ws('_', f.lit(c), f.col(c), f.lit('count')).alias(c)
                 for c in cols]
         )
    ).alias('id'))\
    .groupBy('id')\
    .pivot('x')\
    .count()\
    .na.fill(0)\
    .orderBy('id')\
    .show()

which yields:

+--------------+---+---+---+
|            id|  1|  2|  3|
+--------------+---+---+---+
|field1_a_count|  1|  0|  0|
|field1_b_count|  2|  0|  1|
|field1_c_count|  0|  1|  0|
|field2_a_count|  1|  1|  0|
|field2_b_count|  1|  0|  1|
|field2_c_count|  1|  0|  0|
|field3_a_count|  0|  1|  1|
|field3_b_count|  2|  0|  0|
|field3_c_count|  1|  0|  0|
+--------------+---+---+---+

CodePudding user response:

This is a sign of a design flaw in the data. Whatever the "field1", "field2", etc. columns actually represent, they appear to be related, in the sense that the values quantify some attribute (maybe each one is a count for a specific merchandise ID, or the number of people with a certain property...). The problem is that these fields are being added as individual columns on a fact table [1], which then needs to be aggregated, resulting in the situation that you're facing.

A better design would be to collapse those "field1", "field2", etc. columns into a single code field that can be used as the GROUP BY field when doing the aggregation. If the existing table has many other columns and changing it in place would alter its grain in a way that causes other problems, you might want to create a separate table for this instead.
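
Just to illustrate the idea on your sample data (the field and code column names here are only placeholders), Spark's stack function can collapse the enumerated columns into rows:

# one row per (id, field, code) instead of one column per field
normalized = df.selectExpr(
    "id",
    "stack(3, 'field1', field1, 'field2', field2, 'field3', field3) as (field, code)"
)

# the per-value counts then become a plain GROUP BY
normalized.groupBy('id', 'field', 'code').count().show()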


[1]: It's usually a big red flag to see a table with a bunch of enumerated columns that share the same name and purpose. I've even seen cases where someone created tables with "spare" columns for when they want to add more attributes later. Not good.
