Checking if list of strings in a Scala Dataframe column is present in the value of a Map

Time:02-24

I have the following data:

val df = Seq(
    (1, List("A")),
    (2, List("A")), 
    (3, List("A", "B")),
    (4, List("C")),
    (5, List("A")),
    (6, List("A", "C")),
    (7, List("B")),
    (8, List("A", "B", "C")),
    (9, List("A"))
  ).toDF("Serial Number", "my_list")

+--------------------+--------------------+
|       Serial Number|             my_list|
+--------------------+--------------------+
|                   1|                 [A]|
|                   2|                 [A]|
|                   3|              [A, B]|
|                   4|                 [C]|
|                   5|                 [A]|
|                   6|              [A, C]|
|                   7|                 [B]|
|                   8|           [A, B, C]|
|                   9|                 [A]|
+--------------------+--------------------+

I have a map:

val category_Mapping = Map(
    "Category1" -> Array("A", "B"),
    "Category2" -> Array("C"),
    "Category3" -> Array("B", "D"))

I want to look up each element of the `my_list` column in the map's values and, for each `Serial Number`, return an output map in the following way:

+--------------------+--------------------+------------------------------------------+
|       Serial Number|             my_list|                                   output |
+--------------------+--------------------+------------------------------------------+
|                   1|                 [A]|{Category1->1, Category2->0, Category3->0}|
|                   2|                 [A]|{Category1->1, Category2->0, Category3->0}|
|                   3|              [A, B]|{Category1->1, Category2->0, Category3->1}|
|                   4|                 [C]|{Category1->0, Category2->1, Category3->0}|
|                   5|                 [A]|{Category1->1, Category2->0, Category3->0}|
|                   6|              [A, C]|{Category1->1, Category2->1, Category3->0}|
|                   7|                 [B]|{Category1->1, Category2->0, Category3->1}|
|                   8|           [A, B, C]|{Category1->1, Category2->1, Category3->1}|
|                   9|                 [A]|{Category1->1, Category2->0, Category3->0}|
+--------------------+--------------------+------------------------------------------+

Basically, the output map should have the value 1 for a category whenever any element of `my_list` appears in that category's list in category_Mapping, and 0 otherwise. Is there any way I can do this?

Edit: It's been around 5 hours and nobody has answered. Could someone please help me with this?

CodePudding user response:

You can try this. I have done it this way in Spark local mode, not on a cluster.

// Assuming that your dataframe is stored in a variable called df

// Define a function which will return your map based on the given array in the column 'my_list'

import scala.collection.mutable

def function(lst: mutable.WrappedArray[String]): Map[String, Int] = {
    // Hoisted out of the loop so it is built once per row, not once per element
    val categories = Map("Category1" -> Array("A", "B"), "Category2" -> Array("C"), "Category3" -> Array("B", "D"))
    val map: mutable.Map[String, Int] = mutable.Map("Category1" -> 0, "Category2" -> 0, "Category3" -> 0)
    lst.foreach { l =>
      map.keys.foreach { key =>
        if (categories(key).contains(l))
            map(key) = 1
      }
    }
    map.toMap
}

// now you can define a udf which will just call the above defined function
// (requires: import org.apache.spark.sql.functions.{udf, col})

val output = udf { (lst: mutable.WrappedArray[String]) => function(lst) }

// now you can call the udf on the column 'my_list'

df.withColumn("output", output(col("my_list"))).show(false)

// The output will be as given below

+-------------+---------+------------------------------------------------+
|Serial Number|my_list  |output                                          |
+-------------+---------+------------------------------------------------+
|1            |[A]      |[Category2 -> 0, Category1 -> 1, Category3 -> 0]|
|2            |[A]      |[Category2 -> 0, Category1 -> 1, Category3 -> 0]|
|3            |[A, B]   |[Category2 -> 0, Category1 -> 1, Category3 -> 1]|
|4            |[C]      |[Category2 -> 1, Category1 -> 0, Category3 -> 0]|
|5            |[A]      |[Category2 -> 0, Category1 -> 1, Category3 -> 0]|
|6            |[A, C]   |[Category2 -> 1, Category1 -> 1, Category3 -> 0]|
|7            |[B]      |[Category2 -> 0, Category1 -> 1, Category3 -> 1]|
|8            |[A, B, C]|[Category2 -> 1, Category1 -> 1, Category3 -> 1]|
|9            |[A]      |[Category2 -> 0, Category1 -> 1, Category3 -> 0]|
+-------------+---------+------------------------------------------------+
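For reference, the per-row logic boils down to an intersection test: a category gets 1 exactly when `my_list` shares at least one element with that category's members. A minimal plain-Scala sketch of the same idea (no Spark required), reusing the category map from the question:

```scala
// Plain-Scala sketch of the per-row logic, independent of Spark.
val categories = Map(
  "Category1" -> Seq("A", "B"),
  "Category2" -> Seq("C"),
  "Category3" -> Seq("B", "D"))

// A category maps to 1 when the row's list intersects its members, else 0.
def categorize(lst: Seq[String]): Map[String, Int] =
  categories.map { case (cat, members) =>
    cat -> (if (lst.intersect(members).nonEmpty) 1 else 0)
  }

// categorize(Seq("A", "B")) == Map("Category1" -> 1, "Category2" -> 0, "Category3" -> 1)
```

Note that with an empty input list this still yields a complete map of zeros, which is worth preserving in any UDF version.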

To derive the keys of the output map from category_Mapping instead of hardcoding them, we can broadcast the category_Mapping variable via the SparkContext and use it inside the function to build the output map dynamically. It can be done as follows:

val spark = SparkSession.builder().master("local[*]").getOrCreate()

val category_Mapping = Map("Category1" -> Array("A", "B"), "Category2" -> Array("C"), "Category3" -> Array("B", "D"))

val broadcast_category_Mapping = spark.sparkContext.broadcast(category_Mapping)

// In this function the output map is not hardcoded; its keys come from category_Mapping

def function(lst: mutable.WrappedArray[String]): Map[String, Int] = {
    // Initialise every category to 0 up front, so an empty list still
    // produces a complete map rather than an empty one
    val map: mutable.Map[String, Int] =
        mutable.Map(broadcast_category_Mapping.value.keys.map(_ -> 0).toSeq: _*)
    lst.foreach { l =>
        map.keys.foreach { key =>
            if (broadcast_category_Mapping.value(key).contains(l))
                map(key) = 1
        }
    }
    map.toMap
}

// Rest of the code remains the same
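As an aside, on Spark 2.4 or later the same output column can be built without a UDF at all, which lets Catalyst optimize the query. This is a sketch under that version assumption, reusing the `df` and `category_Mapping` defined above: `arrays_overlap` tests whether `my_list` shares any element with a category's members, and `map(...)` assembles the alternating key/value columns.

```scala
import org.apache.spark.sql.functions._

// Sketch (Spark 2.4+): build the output map column from built-ins instead of a UDF.
// For each category, emit a literal key column and a 0/1 value column.
val outputCol = map(
  category_Mapping.toSeq.flatMap { case (cat, members) =>
    Seq(
      lit(cat),
      when(arrays_overlap(col("my_list"), typedLit(members.toSeq)), 1).otherwise(0))
  }: _*
)

df.withColumn("output", outputCol).show(false)
```

`typedLit` is used rather than `lit` because it reliably handles collection literals such as the per-category member lists.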