I have a dataframe (df1), which looks something like this:
+-------------------------+
|foods                    |
+-------------------------+
|[Apple, Apple, Banana]   |
|[Apple, Carrot, Broccoli]|
|[Spinach]                |
+-------------------------+
and I want to do a lookup in another dataframe (df2) that looks like this:
+--------+---------+
|food    |category |
+--------+---------+
|Apple   |Fruit    |
|Carrot  |Vegetable|
|Broccoli|Vegetable|
|Banana  |Fruit    |
|Spinach |Vegetable|
+--------+---------+
and have the resulting dataframe look like this:
+-------------------------+-----------------------------+---------+
|foods                    |categories                   |has fruit|
+-------------------------+-----------------------------+---------+
|[Apple, Apple, Banana]   |[Fruit, Fruit, Fruit]        |true     |
|[Apple, Carrot, Broccoli]|[Fruit, Vegetable, Vegetable]|true     |
|[Spinach]                |[Vegetable]                  |false    |
+-------------------------+-----------------------------+---------+
How would I be able to do this in Spark/Scala? I am new to Scala, so an explanation of the code may be helpful as well. Thank you!
This is the code I am currently working with, but I am getting an org.apache.spark.SparkException: Task not serializable error, caused by java.io.NotSerializableException: org.apache.spark.sql.Column:
var schema = df1.schema("foods").dataType
def func = udf((x: Seq[String]) => {
  x.map(x => df2.filter(col("food") === x).select(col("category")).head().getString(0))
}, schema)
df1.withColumn("categories", func($"foods")).show()
I would appreciate some help. The code doesn't need to be clean. Thank you.
I've tried turning df2 into a Map, and changed the code a bit:
var mappy = df2.map { r => (r.getString(0), r.getString(1)) }.collect.toMap
var schema = df1.schema("foods").dataType
def func = udf((x: Seq[String]) => {
  x.map(x => mappy.getOrElse(x, ""))
}, schema)
df1.withColumn("categories", func($"foods")).show()
However, now I get this error:
java.lang.StackOverflowError
at java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1189)
at java.io.ObjectOutputStream.defaultWriteFields(ObjectOutputStream.java:1548)
at java.io.ObjectOutputStream.writeSerialData(ObjectOutputStream.java:1509)
at java.io.ObjectOutputStream.writeOrdinaryObject(ObjectOutputStream.java:1432)
at java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1178)
at java.io.ObjectOutputStream.defaultWriteFields(ObjectOutputStream.java:1548)
at java.io.ObjectOutputStream.writeSerialData(ObjectOutputStream.java:1509)
at java.io.ObjectOutputStream.writeOrdinaryObject(ObjectOutputStream.java:1432)
at java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1178)
at java.io.ObjectOutputStream.defaultWriteFields(ObjectOutputStream.java:1548)
at java.io.ObjectOutputStream.writeSerialData(ObjectOutputStream.java:1509)
at java.io.ObjectOutputStream.writeOrdinaryObject(ObjectOutputStream.java:1432)
(((repeats)))
Sorry for the messy code. This is for analysis, not production. Thanks again!
CodePudding user response:
I prepared your input DataFrames like this:
// The import below is needed so you can use the toDF() method to create DataFrames on the fly.
import spark.implicits._

val df1 = List(
  List("Apple", "Apple", "Banana"),
  List("Apple", "Carrot", "Broccoli"),
  List("Spinach")
).toDF("foods")
val df2 = List(
  ("Apple", "Fruit"),
  ("Carrot", "Vegetable"),
  ("Broccoli", "Vegetable"),
  ("Banana", "Fruit"),
  ("Spinach", "Vegetable")
).toDF("food", "category")
My simple solution using DataFrames, a join, and aggregate functions with groupBy gets you the desired output, as follows:
// imports
import org.apache.spark.sql.functions._

// code
val df1_altered = df1
  // explode() creates a new row for every element in the array
  .withColumn("each_food_from_list", explode(col("foods")))
// df1_altered.show()

val df2_altered = df2
  .withColumn(
    "has_fruit",
    when(col("category").equalTo(lit("Fruit")), true).otherwise(false)
  )
// df2_altered.show()

df1_altered
  .join(df2_altered, df1_altered("each_food_from_list") === df2_altered("food"), "inner")
  // groupBy() acts as the opposite of explode(), collapsing multiple rows into one based on a column, using the specified aggregate function(s)
  .groupBy(col("foods"))
  .agg(
    collect_list(col("category")) as "categories",
    max(col("has_fruit")) as "has fruit"
  )
  .show(false)
// +-------------------------+-----------------------------+---------+
// |foods                    |categories                   |has fruit|
// +-------------------------+-----------------------------+---------+
// |[Apple, Apple, Banana]   |[Fruit, Fruit, Fruit]        |true     |
// |[Apple, Carrot, Broccoli]|[Fruit, Vegetable, Vegetable]|true     |
// |[Spinach]                |[Vegetable]                  |false    |
// +-------------------------+-----------------------------+---------+
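As an aside, the lookup-Map + UDF approach from your question can also work. Below is only a rough sketch, assuming df2 is small enough to collect to the driver; the names foodToCategory and categorize are mine, not from the question. The key point is that the function passed to udf() should capture only plain, serializable Scala values such as an immutable Map, never a DataFrame or a Column (referencing df2 and col(...) inside the UDF is what the Task not serializable error points at).

// Sketch: pre-collect df2 into a plain Map on the driver, then use it inside a UDF.
// Assumes the imports from above (spark.implicits._ and org.apache.spark.sql.functions._).
val foodToCategory: Map[String, String] =
  df2.as[(String, String)].collect().toMap

// The closure captures only the Map, so it serializes cleanly to the executors.
val categorize = udf((foods: Seq[String]) =>
  foods.map(f => foodToCategory.getOrElse(f, "Unknown"))
)

df1
  .withColumn("categories", categorize($"foods"))
  .withColumn("has fruit", array_contains($"categories", "Fruit"))
  .show(false)

For a large lookup table, prefer the join-based version above, since it never pulls df2 onto the driver.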
EDIT: If duplicate rows are present in the first DataFrame, you can generate an ID column and then groupBy on both columns.
The 3 changed lines of code are marked as change1, change2 and change3 below:
// change1
val df1_with_id = df1.withColumn("id", monotonically_increasing_id())

// change2
val df1_altered = df1_with_id
  .withColumn("each_food_from_list", explode(col("foods")))
// df1_altered.show()

val df2_altered = df2
  .withColumn(
    "has_fruit",
    when(col("category").equalTo(lit("Fruit")), true).otherwise(false)
  )
// df2_altered.show()

df1_altered
  .join(df2_altered, df1_altered("each_food_from_list") === df2_altered("food"), "inner")
  // change3
  .groupBy(col("id"), col("foods"))
  .agg(
    collect_list(col("category")) as "categories",
    max(col("has_fruit")) as "has fruit"
  )
  .show(false)
// +---+-------------------------+-----------------------------+---------+
// |id |foods                    |categories                   |has fruit|
// +---+-------------------------+-----------------------------+---------+
// |0  |[Apple, Apple, Banana]   |[Fruit, Fruit, Fruit]        |true     |
// |1  |[Apple, Apple, Banana]   |[Fruit, Fruit, Fruit]        |true     |
// |3  |[Spinach]                |[Vegetable]                  |false    |
// |2  |[Apple, Carrot, Broccoli]|[Fruit, Vegetable, Vegetable]|true     |
// +---+-------------------------+-----------------------------+---------+
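If you don't want the helper id column in the final result, a small optional tweak (my addition, not part of the original answer) is to drop it after the aggregation:

// Same pipeline as above, just dropping the helper id column before showing the result
df1_altered
  .join(df2_altered, df1_altered("each_food_from_list") === df2_altered("food"), "inner")
  .groupBy(col("id"), col("foods"))
  .agg(
    collect_list(col("category")) as "categories",
    max(col("has_fruit")) as "has fruit"
  )
  .drop("id")
  .show(false)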