Extract value from ArrayType column in Scala and reshape to long

Time:05-15

I have a DataFrame with an ArrayType column, and the array can have a different length in each row. I have provided some example code below that creates mock data with a similar structure.

As you will see, each transaction has a transaction ID plus some additional data, each piece stored in a "segment". One segment stores customer information (always an array of length two), and there is an additional segment for each item purchased. The information about the purchased item itself is an array of varying length: the first two elements are always the ID and the name of the purchased item; additional elements may exist for color, etc., but we can ignore them in this use case.

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{ArrayType, StringType, StructType}

val dfschema = new StructType()
  .add("transaction",
    new StructType()
      .add("transaction_id", StringType)
      .add(
        "segments",
        ArrayType(
          new StructType()
            .add("segment_id", StringType)
            .add("segment_fields", ArrayType(StringType, false)),
          false
        )
      )
  )


val mockdata = Seq(
  Row(
    Row(
      "2e6d57769e49ae8cb0c4105548c4389d",
      List(
        Row(
          "CustomerInformation",
          List(
            "SomeCustomerName",
            "SomeCustomerEmail"
          )
        ),
        Row(
          "ItemPurchased",
          List(
            "SomeItemID",
            "SomeItemName"
          )
        ),
        Row(
          "ItemPurchased",
          List(
            "AnotherItemID",
            "AnotherItemName",
            "ItemColor"
          )
        ),
        Row(
          "ItemPurchased",
          List(
            "YetAnotherItemID",
            "YetAnotherItemName",
            "ItemColor"
          )
        )
      )
    )
  )
)

val df = spark.createDataFrame(
  spark.sparkContext.parallelize(mockdata),
  dfschema)

What I want to accomplish is to convert the above into another DataFrame with two columns, one for the customer's name and one for the item name. For the above example, it would look like:

customer.name     item.name
SomeCustomerName  SomeItemName
SomeCustomerName  AnotherItemName
SomeCustomerName  YetAnotherItemName

However, I don't want to hardcode the fields I am retrieving; instead, I want to write a couple of functions that I could run as part of a select, like this:

df.select(
  get_single_subsegment("CustomerInformation", 0),
  get_repeated_subsegments("ItemPurchased", 1)
)

This way, if I choose to retrieve the customer email instead of the name, I just need to change 0 to 1 in the above, and I can even pass the index number as a variable.

Can this be done?

Answer:

Since Spark 3.0, you can use Spark's built-in higher-order functions filter and transform to define your two functions, get_single_subsegment and get_repeated_subsegments.

For get_single_subsegment, you can first filter your segments array by segment_id with filter, then take the first element of the filtered array with getItem, and finally retrieve the element at the desired index of that segment's fields using getField and getItem:

import org.apache.spark.sql.Column
import org.apache.spark.sql.functions.{col, filter}

def get_single_subsegment(segmentId: String, index: Int): Column = {
  filter(col("transaction.segments"), c => c.getField("segment_id") === segmentId)
    .getItem(0)
    .getField("segment_fields")
    .getItem(index)
}
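
On Spark 2.4, the Scala filter and transform wrappers in org.apache.spark.sql.functions do not exist yet (they were added in 3.0), but the SQL higher-order functions of the same names do, so a similar column can be built through expr. The sketch below assumes a local SparkSession and a trimmed copy of the question's schema and mock data; get_single_subsegment_sql is a hypothetical name. Because segmentId is spliced into SQL text, it should only come from trusted code, never from user input.

```scala
import org.apache.spark.sql.{Column, Row, SparkSession}
import org.apache.spark.sql.functions.expr
import org.apache.spark.sql.types.{ArrayType, StringType, StructType}

// Hypothetical Spark 2.4-compatible variant of get_single_subsegment.
// It uses the SQL higher-order function `filter` through `expr`.
// Assumption: segmentId is a trusted literal, because it is spliced into SQL text.
def get_single_subsegment_sql(segmentId: String, index: Int): Column =
  expr(
    s"filter(transaction.segments, s -> s.segment_id = '$segmentId')[0].segment_fields[$index]"
  )

val spark = SparkSession.builder().master("local[1]").appName("subsegment-sql").getOrCreate()

// Same shape as the question's dfschema.
val schema = new StructType()
  .add("transaction", new StructType()
    .add("transaction_id", StringType)
    .add("segments", ArrayType(new StructType()
      .add("segment_id", StringType)
      .add("segment_fields", ArrayType(StringType, false)), false)))

// One transaction, trimmed from the question's mock data.
val rows = Seq(Row(Row("2e6d57769e49ae8cb0c4105548c4389d", Seq(
  Row("CustomerInformation", Seq("SomeCustomerName", "SomeCustomerEmail")),
  Row("ItemPurchased", Seq("SomeItemID", "SomeItemName"))))))

val df = spark.createDataFrame(spark.sparkContext.parallelize(rows), schema)

// Index 0 is the customer name, index 1 the customer email.
val customer = df.select(
  get_single_subsegment_sql("CustomerInformation", 0).as("customer_name"),
  get_single_subsegment_sql("CustomerInformation", 1).as("customer_email")
).collect().map(r => (r.getString(0), r.getString(1)))
```

SQL array subscripts such as [0] are 0-based, like getItem, whereas the element_at function is 1-based.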

For get_repeated_subsegments, you first filter as in get_single_subsegment, but then use transform to extract the element at the desired index from each filtered segment's fields, and finally explode the resulting array to get one row per element of the filtered array:

import org.apache.spark.sql.Column
import org.apache.spark.sql.functions.{col, explode, filter, transform}

def get_repeated_subsegments(segmentId: String, index: Int): Column = {
  explode(
    transform(
      filter(col("transaction.segments"), c => c.getField("segment_id") === segmentId)
        .getField("segment_fields"),
      c => c.getItem(index)
    )
  )
}
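
One behaviour worth knowing: explode emits no rows when its input array is empty or null, so a transaction without any ItemPurchased segment disappears from the result entirely. If such transactions should be kept with a null item name, explode_outer can be swapped in. The sketch below is a hypothetical variant named get_repeated_subsegments_outer, run against made-up data (transaction IDs "t1" and "t2" are invented) where the second transaction has no purchased items; it assumes a local SparkSession.

```scala
import org.apache.spark.sql.{Column, Row, SparkSession}
import org.apache.spark.sql.functions.{col, explode_outer, filter, transform}
import org.apache.spark.sql.types.{ArrayType, StringType, StructType}

// Hypothetical variant of get_repeated_subsegments: explode_outer keeps rows whose
// filtered array is empty, emitting a single null instead of dropping the row.
def get_repeated_subsegments_outer(segmentId: String, index: Int): Column =
  explode_outer(
    transform(
      filter(col("transaction.segments"), c => c.getField("segment_id") === segmentId)
        .getField("segment_fields"),
      c => c.getItem(index)
    )
  )

val spark = SparkSession.builder().master("local[1]").appName("subsegment-outer").getOrCreate()

// Same shape as the question's dfschema.
val schema = new StructType()
  .add("transaction", new StructType()
    .add("transaction_id", StringType)
    .add("segments", ArrayType(new StructType()
      .add("segment_id", StringType)
      .add("segment_fields", ArrayType(StringType, false)), false)))

// Made-up data: "t2" has a customer segment but no ItemPurchased segment.
val rows = Seq(
  Row(Row("t1", Seq(
    Row("CustomerInformation", Seq("SomeCustomerName", "SomeCustomerEmail")),
    Row("ItemPurchased", Seq("SomeItemID", "SomeItemName"))))),
  Row(Row("t2", Seq(
    Row("CustomerInformation", Seq("OtherCustomerName", "OtherCustomerEmail"))))))

val df = spark.createDataFrame(spark.sparkContext.parallelize(rows), schema)

// Null item names become None so the result is easy to compare.
val result = df.select(
  col("transaction.transaction_id").as("transaction_id"),
  get_repeated_subsegments_outer("ItemPurchased", 1).as("item_name")
).collect().map(r => (r.getString(0), Option(r.getString(1)))).sortBy(_._1)
```

With plain explode, the "t2" row would be missing from result instead of carrying a null item name.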

If we apply the two functions defined above to your example, we get the following result:

df.select(
  get_single_subsegment("CustomerInformation", 0).as("customer_name"),
  get_repeated_subsegments("ItemPurchased", 1).as("item_name")
).show(false)

//+----------------+------------------+
//|customer_name   |item_name         |
//+----------------+------------------+
//|SomeCustomerName|SomeItemName      |
//|SomeCustomerName|AnotherItemName   |
//|SomeCustomerName|YetAnotherItemName|
//+----------------+------------------+
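
Because both helpers take the index as a plain Int, the question's last requirement (choosing the field at run time, e.g. the customer email instead of the name) needs no extra code. A minimal sketch, assuming a local SparkSession and a trimmed copy of the mock data; customerFieldIndex and itemFieldIndex are illustrative names, and the two helpers are repeated from the answer so the block runs on its own.

```scala
import org.apache.spark.sql.{Column, Row, SparkSession}
import org.apache.spark.sql.functions.{col, explode, filter, transform}
import org.apache.spark.sql.types.{ArrayType, StringType, StructType}

// The answer's two helpers, repeated so this block is self-contained.
def get_single_subsegment(segmentId: String, index: Int): Column =
  filter(col("transaction.segments"), c => c.getField("segment_id") === segmentId)
    .getItem(0)
    .getField("segment_fields")
    .getItem(index)

def get_repeated_subsegments(segmentId: String, index: Int): Column =
  explode(
    transform(
      filter(col("transaction.segments"), c => c.getField("segment_id") === segmentId)
        .getField("segment_fields"),
      c => c.getItem(index)
    )
  )

val spark = SparkSession.builder().master("local[1]").appName("subsegment-index").getOrCreate()

// Same shape as the question's dfschema.
val schema = new StructType()
  .add("transaction", new StructType()
    .add("transaction_id", StringType)
    .add("segments", ArrayType(new StructType()
      .add("segment_id", StringType)
      .add("segment_fields", ArrayType(StringType, false)), false)))

// Trimmed copy of the question's mock data.
val rows = Seq(Row(Row("2e6d57769e49ae8cb0c4105548c4389d", Seq(
  Row("CustomerInformation", Seq("SomeCustomerName", "SomeCustomerEmail")),
  Row("ItemPurchased", Seq("SomeItemID", "SomeItemName")),
  Row("ItemPurchased", Seq("AnotherItemID", "AnotherItemName", "ItemColor"))))))

val df = spark.createDataFrame(spark.sparkContext.parallelize(rows), schema)

// Field positions chosen at run time: 1 = customer email, 0 = item ID.
val customerFieldIndex = 1
val itemFieldIndex = 0

val pairs = df.select(
  get_single_subsegment("CustomerInformation", customerFieldIndex).as("customer_email"),
  get_repeated_subsegments("ItemPurchased", itemFieldIndex).as("item_id")
).collect().map(r => (r.getString(0), r.getString(1)))
```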