Apache Spark Custom groupBy on Dataframe based on value count


I have sample data as below and want to group it based on the count of values. There is a threshold on the number of values in each group: if a group has more than 3 values, the extra values need to go into another row. Can someone please advise? I have a normal groupBy, but how do I transform the result based on the threshold count? I need to call an API that accepts at most x values in its input, so I need to somehow break up the values.

val tdf1=tDf.groupBy(col("attr1")).agg(concat_ws(",",collect_list(col("attr2"))).as("value"))

Input:

scala> tdf1.show(false)
+-----+-----------------------+
|attr1|value                  |
+-----+-----------------------+
|1    |2,20                   |
|2    |200,201,202,203,204,205|
+-----+-----------------------+

Expectation:

scala> tdf1.show(false)
+-----+-----------------------+
|attr1|value                  |
+-----+-----------------------+
|1    |2,20                   |
|2    |200,201,202            |
|2    |203,204,205            |
+-----+-----------------------+

Input file:

[{
  "attr1":"1",
  "attr2":"2"
},{
  "attr1":"1",
  "attr2":"20"
},{
  "attr1":"2",
  "attr2":"200"
},{
  "attr1":"2",
  "attr2":"201"
},{
  "attr1":"2",
  "attr2":"202"
},{
  "attr1":"2",
  "attr2":"203"
},{
  "attr1":"2",
  "attr2":"204"
},{
  "attr1":"2",
  "attr2":"205"
}]
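
For reference, the sample above can be loaded into a DataFrame as follows (a minimal sketch; the path input.json is an assumption, and multiLine is required because each JSON record spans several lines). The answer below refers to this DataFrame as data_sdf:

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# read the multi-line JSON array; 'input.json' is a hypothetical path
data_sdf = spark.read.option('multiLine', True).json('input.json')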

CodePudding user response:

We can do this using the collect_list() window function over a sliding frame that selects only 3 records at a time.

First, create the list of 3 elements for every record's window, along with a row number to keep track of each record's position within its group. The row number will be used later for filtering.

from pyspark.sql import functions as func
from pyspark.sql.window import Window as wd

# ordering the collect_list window by attr2 makes the sliding frame deterministic
data_sdf. \
    withColumn('rn', func.row_number().over(wd.partitionBy('attr1').orderBy('attr2')) - 1). \
    withColumn('attr3', func.concat_ws(',', func.collect_list('attr2').over(wd.partitionBy('attr1').orderBy('attr2').rowsBetween(0, 2)))). \
    show()

# +-----+-----+---+-----------+
# |attr1|attr2| rn|      attr3|
# +-----+-----+---+-----------+
# |    1|    2|  0|       2,20|
# |    1|   20|  1|         20|
# |    2|  200|  0|200,201,202|
# |    2|  201|  1|201,202,203|
# |    2|  202|  2|202,203,204|
# |    2|  203|  3|203,204,205|
# |    2|  204|  4|    204,205|
# |    2|  205|  5|        205|
# +-----+-----+---+-----------+

Once we have it for all records, we can retain just the rows that give us the 3 elements, i.e., every 3rd record starting from the first one in each group (rn = 0, 3, 6, ...).

data_sdf. \
    withColumn('rn', func.row_number().over(wd.partitionBy('attr1').orderBy('attr2')) - 1). \
    withColumn('attr3', func.concat_ws(',', func.collect_list('attr2').over(wd.partitionBy('attr1').orderBy('attr2').rowsBetween(0, 2)))). \
    filter((func.col('rn') == 0) | (func.col('rn') % 3 == 0)). \
    show()

# +-----+-----+---+-----------+
# |attr1|attr2| rn|      attr3|
# +-----+-----+---+-----------+
# |    1|    2|  0|       2,20|
# |    2|  200|  0|200,201,202|
# |    2|  203|  3|203,204,205|
# +-----+-----+---+-----------+

We can drop the rn and attr2 columns after the filter in the step above, and rename attr3 to value to match the expected output.
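
For completeness, here is a minimal sketch of the full pipeline with that cleanup applied, parameterized by the chunk size (the variable n and the rename of attr3 to value are assumptions; n = 3 matches the threshold in the question):

from pyspark.sql import functions as func
from pyspark.sql.window import Window as wd

n = 3  # max number of values per output row (hypothetical API limit)

result_sdf = data_sdf. \
    withColumn('rn', func.row_number().over(wd.partitionBy('attr1').orderBy('attr2')) - 1). \
    withColumn('value', func.concat_ws(',', func.collect_list('attr2').over(wd.partitionBy('attr1').orderBy('attr2').rowsBetween(0, n - 1)))). \
    filter(func.col('rn') % n == 0). \
    drop('rn', 'attr2')

result_sdf.show()
# +-----+-----------+
# |attr1|      value|
# +-----+-----------+
# |    1|       2,20|
# |    2|200,201,202|
# |    2|203,204,205|
# +-----+-----------+

Changing n adjusts the chunk size, so the same pipeline works for whatever limit the target API imposes.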
