pyspark: how to group N records in a spark dataframe


I have a CSV with 5 million records, with the structure:

+----------+------------+------------+
|  row_id  |    col1    |    col2    |
+----------+------------+------------+
|         1|   value    |    value   |
|         2|   value    |    value   |
|       ...|     ...    |     ...    |
|   5000000|   value    |    value   |
+----------+------------+------------+

I need to convert this CSV to JSON, with each JSON file containing 500 records and following this structure:

{
    "entry": [
        {
            "row_id": "1",
            "col1": "value",
            "col2": "value"
        },
        {
            "row_id": "2",
            "col1": "value",
            "col2": "value"
        },
        ...
        {
            "row_id": "500",
            "col1": "value",
            "col2": "value"
        }
    ],
    "last_updated":"09-09-2021T01:03:04.44Z"
}

Using PySpark I am able to read the CSV and create a dataframe, but I don't know how to group 500 records into a single JSON with the structure "entry": [ <500 records> ], "last_updated": "09-09-2021T01:03:04.44Z".
I can use df.coalesce(1).write.option("maxRecordsPerFile", 500), but that only gives me plain sets of 500 records, without the surrounding structure. I want those 500 records inside the "entry" list, with "last_updated" following it (which I am taking from datetime.now()).
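
For reference, a minimal sketch of how such a CSV can be read into a dataframe (the path and the header/inferSchema options here are assumptions about the input file):

from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("csv-to-json-batches").getOrCreate()

# Read the source CSV; header/inferSchema are assumptions about the input file.
df = spark.read.option("header", "true").option("inferSchema", "true").csv("/path/to/input.csv")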

CodePudding user response:

You may try the following:


NB. I've used the following imports.

from datetime import datetime

from pyspark.sql import functions as F
from pyspark.sql import Window

1. We need a column that can be used to split your data into 500-record batches

(Recommended) We can create a pseudo column to achieve this with row_number

df = df.withColumn("group_num", F.floor((F.row_number().over(Window.orderBy("row_id")) - 1) / 500))

Otherwise, if row_id starts at 1 and increases consecutively across the 5 million records, we may use it directly

df = df.withColumn("group_num", F.floor((F.col("row_id") - 1) / 500))

or, in the odd chance that a column last_updated ("09-09-2021T01:03:04.44Z") already exists in your data and is unique to each batch of 500 records

df = df.withColumn("group_num",F.col("last_updated"))
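
Whichever variant you use, a quick sanity check of the batch sizes is worthwhile; with the floor-division variants above, every group except possibly the last should contain exactly 500 rows:

df.groupBy("group_num").count().orderBy("group_num").show(5)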

2. We will transform your dataset by grouping by the group_num

df = (
    df.groupBy("group_num")
      .agg(
          F.collect_list(
              F.expr("struct(row_id,col1,col2)")
          ).alias("entry")
      )
      .withColumn("last_updated", F.lit(datetime.now()))
      .drop("group_num")
)

NB. If you would like to include all columns you may use F.expr("struct(*)") instead of F.expr("struct(row_id,col1,col2)").
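
If you need last_updated as a string matching the sample "09-09-2021T01:03:04.44Z" rather than a Spark timestamp, one option (a sketch; the day-first format is an assumption read off the sample value) is to format it in Python and overwrite the column produced in step 2:

from datetime import datetime, timezone

# Format the current UTC time to roughly match "09-09-2021T01:03:04.44Z";
# microseconds are truncated to two digits to mimic the ".44" fraction.
now = datetime.now(timezone.utc)
last_updated = now.strftime("%d-%m-%YT%H:%M:%S.") + "%02dZ" % (now.microsecond // 10000)

df = df.withColumn("last_updated", F.lit(last_updated))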


3. Finally, you can write to your output/destination with the option .option("maxRecordsPerFile", 1), since each row now stores at most 500 entries

Eg.

df.write.format("json").option("maxRecordsPerFile",1).save("<your intended path here>")
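
Each part file written this way should contain a single JSON object with the entry array and the last_updated field. As an optional check, you can read the output back and confirm each row (i.e. each file) holds up to 500 entries (spark here is your existing SparkSession, and the path is the same placeholder as above):

check = spark.read.json("<your intended path here>")
check.printSchema()
check.select(F.size("entry").alias("records_per_file")).show(5)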

Let me know if this works for you
