How can I save a single column of a pyspark dataframe in multiple json files?

Time:07-31

I have a dataframe that looks a bit like this:

| key 1 | key 2 | key 3 | body |

I want to save this dataframe as one JSON file per partition, where a partition is a unique combination of keys 1 to 3. I have the following requirements:

  • The file paths should be /key 1/key 2/key 3.json.gz, where each path segment is that row's value of the key.
  • The files should be gzip-compressed.
  • The contents of each file should be the values of body (this column contains a JSON string), one JSON string per line.

I've tried multiple things, but with no luck.

Method 1: Using the native DataFrame writer. I've tried using the native write method to save the data, something like this:

df.write \
  .partitionBy("key 1", "key 2", "key 3") \
  .mode('overwrite') \
  .format('json') \
  .option("codec", "org.apache.hadoop.io.compress.GzipCodec") \
  .save(
    path=path,
    compression="gzip"
  )

This solution doesn't store the files in the correct path or with the correct name, but that can be fixed by moving them afterwards. The bigger problem is that it writes the complete dataframe, while I only want to write the values of the body column. I still need the other columns to partition the data, though.
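A minimal sketch of that "move them afterwards" step, assuming the write output sits on a locally accessible filesystem (a local path or a mount) and that each leaf directory holds exactly one part file, for example because the DataFrame was repartitioned by the three key columns before writing. Spark lays out partitionBy output Hive-style (key 1=value/key 2=value/key 3=value/part-...) and percent-escapes special characters in the values, which is why the values are unquoted here. The function name relocate_partition_files is hypothetical. This only fixes paths and names; with format('json') each line is still a JSON object wrapping the body column, not the raw body string.

```python
import shutil
from pathlib import Path
from urllib.parse import unquote


def relocate_partition_files(spark_out, target_root,
                             keys=("key 1", "key 2", "key 3"),
                             suffix=".json.gz"):
    """Move Hive-style partition output to <target_root>/<v1>/<v2>/<v3><suffix>.

    Hypothetical helper. Assumes a locally accessible filesystem and exactly
    one part file per leaf partition directory.
    """
    moved = []
    # e.g. "key 1=*/key 2=*/key 3=*"
    pattern = "/".join(f"{k}=*" for k in keys)
    for leaf in sorted(Path(spark_out).glob(pattern)):
        if not leaf.is_dir():
            continue
        # Recover the raw key values; Spark percent-escapes special characters.
        rel = leaf.relative_to(spark_out)
        values = [unquote(part.split("=", 1)[1]) for part in rel.parts]
        # Skip checksum files such as ".part-....crc" and markers like "_SUCCESS".
        part_files = sorted(p for p in leaf.iterdir() if p.name.startswith("part-"))
        if len(part_files) != 1:
            raise ValueError(f"expected one part file in {leaf}, found {len(part_files)}")
        dest = Path(target_root, *values[:-1], values[-1] + suffix)
        dest.parent.mkdir(parents=True, exist_ok=True)
        shutil.move(str(part_files[0]), str(dest))
        moved.append(dest)
    return moved
```

Raising on more than one part file is deliberate: silently picking one would drop data.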

Method 2: Using the Hadoop filesystem. It's possible to call the Hadoop FileSystem Java library directly via sc._gateway.jvm.org.apache.hadoop.fs.FileSystem. With access to this filesystem I can create the files myself, which gives me more control over the path, the file name and the contents. However, to make this code scale I'm doing it per partition:

def save_partition(items):
  # Runs on a worker; `items` iterates over the Rows of one partition.
  for row in items:
    pass  # store the row here

df.foreachPartition(save_partition)

However, I can't get this to work because save_partition is executed on the workers, which don't have access to the SparkSession or the SparkContext (needed to reach the Hadoop FileSystem JVM libraries). I could solve this by pulling all the data to the driver with collect() and saving it from there, but that won't scale.
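One way around the missing SparkContext, sketched under a loud assumption: if the destination is a directory mounted at the same path on every worker (for example a shared network or cloud-storage mount), the callback can use plain Python gzip file I/O instead of the Hadoop FileSystem API, so it needs no SparkSession at all. make_save_partition and its root argument are hypothetical names. Rows are grouped in memory per partition, so this assumes each partition fits in worker memory.

```python
import gzip
import os
from collections import defaultdict


def make_save_partition(root, keys=("key 1", "key 2", "key 3"), body_col="body"):
    """Build a foreachPartition callback that needs no SparkContext.

    Hypothetical helper. Assumes `root` is mounted at the same path on every
    worker; uses plain Python file I/O, not the Hadoop FileSystem API.
    """
    def save_partition(rows):
        # Group body strings by their (key 1, key 2, key 3) values.
        groups = defaultdict(list)
        for row in rows:
            # row["name"] works for pyspark.sql.Row and for plain dicts.
            groups[tuple(str(row[k]) for k in keys)].append(row[body_col])
        for values, bodies in groups.items():
            *dirs, leaf = values
            directory = os.path.join(root, *dirs)
            os.makedirs(directory, exist_ok=True)
            path = os.path.join(directory, leaf + ".json.gz")
            # "wt" writes gzip-compressed text: one JSON string per line.
            with gzip.open(path, "wt", encoding="utf-8") as fh:
                for body in bodies:
                    fh.write(body + "\n")
    return save_partition
```

It would be invoked as df.repartition("key 1", "key 2", "key 3").foreachPartition(make_save_partition("/mnt/shared/output")), where the mount path is a placeholder. Repartitioning by the keys matters: it sends every row of a given key combination to the same partition, so no two tasks write the same file. The sketch has no commit protocol, so a failed task can leave a partial file behind.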

So, quite a story, but I prefer to be complete. What am I missing? Is what I want impossible, or am I missing something obvious? Is it just difficult? Or is it only possible from Scala/Java? I would love some help with this.

Answer:

This may be slightly tricky to do in pure PySpark. Creating too many partitions is not recommended, and from what you have explained, I think you are partitioning only to get one JSON body per file. You may need a bit of Scala here, but your Spark job can still remain a PySpark job.

Internally, Spark defines DataSource interfaces through which you can specify how data is read and written. JSON is one such data source. You can extend the default JsonFileFormat class to create your own JsonFileFormatV2. You will also need a JsonOutputWriterV2 class extending the default JsonOutputWriter. The output writer has a write function that gives you access to the individual rows and to the output path passed in from the Spark program; you can modify that function to meet your needs.

Here is a sample of how I achieved customizing JSON writes for my use case of writing a fixed number of JSON entries per file. You can use it as a reference for implementing your own JSON writing strategy.

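// Must live in this package: JsonOutputWriterV2 is private[json], and PySpark
// loads the format by this fully qualified class name.
package org.apache.spark.sql.execution.datasources.json

import java.nio.charset.{Charset, StandardCharsets}

import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JSONOptions}
import org.apache.spark.sql.catalyst.util.CompressionCodecs
import org.apache.spark.sql.execution.datasources.{CodecStreams, OutputWriter, OutputWriterFactory}
import org.apache.spark.sql.types.StructType

class JsonFileFormatV2 extends JsonFileFormat {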
  override val shortName: String = "jsonV2"

  override def prepareWrite(
                             sparkSession: SparkSession,
                             job: Job,
                             options: Map[String, String],
                             dataSchema: StructType): OutputWriterFactory = {
    val conf = job.getConfiguration
    val fileLineCount = options.get("filelinecount").map(_.toInt).getOrElse(1)
    val parsedOptions = new JSONOptions(
      options,
      sparkSession.sessionState.conf.sessionLocalTimeZone,
      sparkSession.sessionState.conf.columnNameOfCorruptRecord)
    parsedOptions.compressionCodec.foreach { codec =>
      CompressionCodecs.setCodecConfiguration(conf, codec)
    }

    new OutputWriterFactory {
      override def newInstance(
                                path: String,
                                dataSchema: StructType,
                                context: TaskAttemptContext): OutputWriter = {
        new JsonOutputWriterV2(path, parsedOptions, dataSchema, context, fileLineCount)
      }

      override def getFileExtension(context: TaskAttemptContext): String = {
        ".json" + CodecStreams.getCompressionExtension(context)
      }
    }
  }

}


private[json] class JsonOutputWriterV2(
                                        path: String,
                                        options: JSONOptions,
                                        dataSchema: StructType,
                                        context: TaskAttemptContext,
                                        maxFileLineCount: Int) extends JsonOutputWriter(
  path,
  options,
  dataSchema,
  context) {

  private val encoding = options.encoding match {
    case Some(charsetName) => Charset.forName(charsetName)
    case None => StandardCharsets.UTF_8
  }
  var recordCounter = 0
  var filecounter = 0
  private val maxEntriesPerFile = maxFileLineCount

  private var writer = CodecStreams.createOutputStreamWriter(
    context, new Path(modifiedPath(path)), encoding)

  private[this] var gen = new JacksonGenerator(dataSchema, writer, options)

  private def modifiedPath(path:String): String = {
    val np = s"$path-filecount-$filecounter"
    np
  }

  override def write(row: InternalRow): Unit = {
    gen.write(row)
    gen.writeLineEnding()
    recordCounter += 1
    if(recordCounter >= maxEntriesPerFile){
      gen.close()
      writer.close()
      filecounter += 1
      recordCounter = 0
      writer = CodecStreams.createOutputStreamWriter(
        context, new Path(modifiedPath(path)), encoding)
      gen = new JacksonGenerator(dataSchema, writer, options)
    }
  }

  override def close(): Unit = {
    if(recordCounter<maxEntriesPerFile){
      gen.close()
      writer.close()
    }
  }
}
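The counting and naming logic in JsonOutputWriterV2.write can be checked in isolation. The plain-Python model below (rotate_records is a hypothetical name; it performs no I/O and only mirrors the Scala counters) shows which records end up in which -filecount-N file. One consequence worth noticing: because the Scala writer opens the next file as soon as the current one fills up, a record count that is an exact multiple of filelinecount leaves an empty trailing file.

```python
def rotate_records(records, max_entries_per_file, path="part-00000"):
    """Model of JsonOutputWriterV2's file rotation (no actual I/O).

    Returns a dict mapping each file name to the records written to it.
    """
    file_counter = 0
    record_counter = 0

    def name(counter):
        # Mirrors modifiedPath: s"$path-filecount-$filecounter"
        return f"{path}-filecount-{counter}"

    # The first writer is created eagerly, before any record arrives.
    files = {name(file_counter): []}
    for record in records:
        files[name(file_counter)].append(record)  # gen.write + writeLineEnding
        record_counter += 1
        if record_counter >= max_entries_per_file:
            # Close the current file and immediately open the next one.
            file_counter += 1
            record_counter = 0
            files[name(file_counter)] = []
    return files
```

With five records and filelinecount=2 this yields files holding [0, 1], [2, 3] and [4]; with four records the third file is empty.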

Package this custom data source into a JAR and add it to the Spark classpath (for example with spark-submit --jars); you can then invoke it from PySpark as follows:

df.write.format("org.apache.spark.sql.execution.datasources.json.JsonFileFormatV2").option("filelinecount","5").mode("overwrite").save("path-to-save")
