I have the following data:
val df = Seq(
(1, List("A")),
(2, List("A")),
(3, List("A", "B")),
(4, List("C")),
(5, List("A")),
(6, List("A", "C")),
(7, List("B")),
(8, List("A", "B", "C")),
(9, List("A"))
).toDF("Serial Number", "my_list")
+-------------+---------+
|Serial Number|  my_list|
+-------------+---------+
|            1|      [A]|
|            2|      [A]|
|            3|   [A, B]|
|            4|      [C]|
|            5|      [A]|
|            6|   [A, C]|
|            7|      [B]|
|            8|[A, B, C]|
|            9|      [A]|
+-------------+---------+
I have a map:
val category_Mapping = Map("Category1" -> List("A", "B"),
  "Category2" -> List("C"),
  "Category3" -> List("B", "D"))
For each row (each "Serial Number"), I want to check every element of "my_list" against this map and return an output map, as follows:
+-------------+---------+------------------------------------------+
|Serial Number|  my_list|                                    output|
+-------------+---------+------------------------------------------+
|            1|      [A]|{Category1->1, Category2->0, Category3->0}|
|            2|      [A]|{Category1->1, Category2->0, Category3->0}|
|            3|   [A, B]|{Category1->1, Category2->0, Category3->1}|
|            4|      [C]|{Category1->0, Category2->1, Category3->0}|
|            5|      [A]|{Category1->1, Category2->0, Category3->0}|
|            6|   [A, C]|{Category1->1, Category2->1, Category3->0}|
|            7|      [B]|{Category1->1, Category2->0, Category3->1}|
|            8|[A, B, C]|{Category1->1, Category2->1, Category3->1}|
|            9|      [A]|{Category1->1, Category2->0, Category3->0}|
+-------------+---------+------------------------------------------+
In other words, for each category the output map should hold 1 if any element of "my_list" appears in that category's list in category_Mapping, and 0 otherwise. Is there any way I can do this?
Edit: It's been around 5 hours and nobody has answered. Could someone please help me with this?
Answer:
You can try this. Note that I have tested it only in Spark local mode, not on a cluster.
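Before wiring anything into Spark, the per-row rule can be sketched in plain Scala. This is a minimal, immutable sketch under stated assumptions: the names CategoryFlagsSketch and categoryFlags are hypothetical, and the mapping is the question's category_Mapping written with List values.

```scala
// Hypothetical helper object; not part of the original answer
object CategoryFlagsSketch {
  // The question's category_Mapping, written as valid Scala
  val categoryMapping: Map[String, List[String]] =
    Map("Category1" -> List("A", "B"), "Category2" -> List("C"), "Category3" -> List("B", "D"))

  // For each category: 1 if any element of myList is one of its members, else 0
  def categoryFlags(myList: Seq[String]): Map[String, Int] =
    categoryMapping.map { case (category, members) =>
      category -> (if (myList.exists(l => members.contains(l))) 1 else 0)
    }
}
```

For example, categoryFlags(Seq("B")) gives 1 for Category1 and Category3 and 0 for Category2, matching row 7 of the expected output. Because it takes a Seq[String] and builds no mutable state, a function like this should also be usable inside a udf, since Spark can hand an array column to a Scala UDF as a Seq.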
// Assuming that your DataFrame is stored in a variable called df
import scala.collection.mutable
import org.apache.spark.sql.functions.{col, udf}

// Define a function that returns your map based on the array in the column 'my_list'.
// Spark passes an array<string> column to a Scala UDF as a WrappedArray.
def function(lst: mutable.WrappedArray[String]): Map[String, Int] = {
  val categoryMapping = Map("Category1" -> Array("A", "B"), "Category2" -> Array("C"), "Category3" -> Array("B", "D"))
  // Every category starts at 0
  val map: mutable.Map[String, Int] = mutable.Map("Category1" -> 0, "Category2" -> 0, "Category3" -> 0)
  lst.foreach { l =>
    categoryMapping.keys.foreach { key =>
      // Flag the category with 1 if it contains the current list element
      if (categoryMapping(key).contains(l))
        map(key) = 1
    }
  }
  map.toMap
}

// Now you can define a UDF that just calls the function above
val output = udf { (lst: mutable.WrappedArray[String]) => function(lst) }

// Now you can call the UDF on the column 'my_list'
df.withColumn("output", output(col("my_list"))).show(false)
// The output will be as given below
+-------------+---------+------------------------------------------------+
|Serial Number|my_list  |output                                          |
+-------------+---------+------------------------------------------------+
|1            |[A]      |[Category2 -> 0, Category1 -> 1, Category3 -> 0]|
|2            |[A]      |[Category2 -> 0, Category1 -> 1, Category3 -> 0]|
|3            |[A, B]   |[Category2 -> 0, Category1 -> 1, Category3 -> 1]|
|4            |[C]      |[Category2 -> 1, Category1 -> 0, Category3 -> 0]|
|5            |[A]      |[Category2 -> 0, Category1 -> 1, Category3 -> 0]|
|6            |[A, C]   |[Category2 -> 1, Category1 -> 1, Category3 -> 0]|
|7            |[B]      |[Category2 -> 0, Category1 -> 1, Category3 -> 1]|
|8            |[A, B, C]|[Category2 -> 1, Category1 -> 1, Category3 -> 1]|
|9            |[A]      |[Category2 -> 0, Category1 -> 1, Category3 -> 0]|
+-------------+---------+------------------------------------------------+
To derive the keys of the output map from category_Mapping instead of hardcoding them, you can broadcast category_Mapping with sparkContext and read the broadcast value inside the function. It can be done as follows:
import scala.collection.mutable
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").getOrCreate()
val category_Mapping = Map("Category1" -> Array("A", "B"), "Category2" -> Array("C"), "Category3" -> Array("B", "D"))
val broadcast_category_Mapping = spark.sparkContext.broadcast(category_Mapping)

// In this function the output map is not hardcoded; its keys come from category_Mapping
def function(lst: mutable.WrappedArray[String]): Map[String, Int] = {
  val mapping = broadcast_category_Mapping.value
  // Initialise every category to 0 up front, so an empty list still yields all keys
  val map: mutable.Map[String, Int] = mutable.Map(mapping.keys.map(_ -> 0).toSeq: _*)
  lst.foreach { l =>
    mapping.keys.foreach { key =>
      if (mapping(key).contains(l))
        map(key) = 1
    }
  }
  map.toMap
}

// Rest of the code remains the same
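If category_Mapping grows large, one possible variation (not part of the answer above) is to invert it once into an element-to-categories index, so each row does one lookup per list element instead of scanning every category. This is a minimal plain-Scala sketch; the names InvertedIndexSketch, elementToCategories and categoryFlags are hypothetical. In Spark, the index would presumably be built on the driver and broadcast the same way as broadcast_category_Mapping.

```scala
// Hypothetical sketch of an inverted-index lookup; not part of the original answer
object InvertedIndexSketch {
  // Same contents as category_Mapping in the answer
  val categoryMapping: Map[String, Seq[String]] =
    Map("Category1" -> Seq("A", "B"), "Category2" -> Seq("C"), "Category3" -> Seq("B", "D"))

  // Built once: element -> set of categories that contain it (e.g. "B" -> Set(Category1, Category3))
  val elementToCategories: Map[String, Set[String]] =
    categoryMapping.toSeq
      .flatMap { case (category, members) => members.map(member => member -> category) }
      .groupBy(_._1)
      .map { case (element, pairs) => element -> pairs.map(_._2).toSet }

  // Per row: collect every category hit by any list element, then emit 1/0 for each category
  def categoryFlags(lst: Seq[String]): Map[String, Int] = {
    val hit: Set[String] = lst.flatMap(l => elementToCategories.getOrElse(l, Set.empty[String])).toSet
    categoryMapping.keys.map(k => k -> (if (hit.contains(k)) 1 else 0)).toMap
  }
}
```

The trade-off is one extra preprocessing pass over the mapping in exchange for per-row work that depends on the list length rather than on the number of categories; for a three-category mapping like this one, the simpler loop in the answer is just as good.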