Spark Join for Each Item in List


I have a Spark Dataset like

+----------+-------+----+---+--------------+
|        _1|     _2|  _3| _4|            _5|
+----------+-------+----+---+--------------+
|      null|1111111|null| 15|       [98765]|
|      null|2222222|null| 16|[97008, 98765]|
|6436334664|3333333|null| 15|       [97008]|
|2356242642|4444444|null| 11|       [97008]|
+----------+-------+----+---+--------------+

Where the fifth column is a list of zip codes associated with that row. I have another table that has a unique row for each zip code with its corresponding longitude and latitude. I want to create a table like

+----------+-------+----+---+--------------+----------------------------------+
|        _1|     _2|  _3| _4|            _5|                                _6|
+----------+-------+----+---+--------------+----------------------------------+
|3572893528|1111111|null| 15|       [98765]| [(54.12,-80.53)]                 |
|5325232523|2222222|null| 16|[98765, 97008]| [(54.12,-80.53), (44.12,-75.11)] |
|6436334664|3333333|null| 15|       [97008]| [(44.12,-75.11)]                 |
|2356242642|4444444|null| 11|       [97008]| [(44.12,-75.11)]                 |
+----------+-------+----+---+--------------+----------------------------------+

where the sixth column holds the coordinates for the zip codes, in the same order as the fifth column.

I have tried just filtering the zipcode table every time I need coordinates, but I get an NPE, I think for reasons similar to those detailed in this question. If I try to collect the zipcode table before filtering it, I run out of memory.
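
For reference, the failing pattern presumably looked something like the sketch below (a hypothetical reconstruction; zipDs and its zip column name are assumptions). Dataset handles are only valid on the driver, so dereferencing one inside a transformation running on the executors yields a NullPointerException:

import org.apache.spark.sql.functions.col

// Anti-pattern: referencing the zipDs Dataset inside a transformation on ds.
// On the executors the zipDs handle is null, hence the NPE.
ds.foreach { row =>
  val zips = row.getAs[Seq[String]]("_5")
  val coords = zipDs.filter(col("zip").isin(zips: _*)).collect() // NPE here
  // ... use coords ...
}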

I am using Scala and I got the original Dataset using Spark SQL in a Spark job. Any solutions would be appreciated, thank you.

CodePudding user response:

Let's assume (the comment on your question holds true and) we have two datasets (simplifying your example), ds and ds2, respectively:

+---+--------------+
|_1 |_2            |
+---+--------------+
|15 |[98765]       |
|16 |[97008, 98765]|
|15 |[97008]       |
|15 |[97008]       |
+---+--------------+

+-----+---------------+
|_2   |_3             |
+-----+---------------+
|98765|{54.12, -80.53}|
|97008|{44.12, -75.11}|
+-----+---------------+
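
For reference, a minimal sketch of how such sample datasets could be built locally (the local SparkSession setup and the Scala literals are assumptions, not part of the original job):

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("zip-join").getOrCreate()
import spark.implicits._

// ds: one row per record; _2 holds the zip codes as an array.
// Declared as var because it is reassigned in the steps below.
var ds = Seq(
  (15, Seq("98765")),
  (16, Seq("97008", "98765")),
  (15, Seq("97008")),
  (15, Seq("97008"))
).toDF("_1", "_2")

// ds2: one row per zip code; _3 is a (lat, lon) struct.
val ds2 = Seq(
  ("98765", (54.12, -80.53)),
  ("97008", (44.12, -75.11))
).toDF("_2", "_3")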

The idea is to create a unique ID (so we can join back later), explode the zip-code array, join to get the coordinates per unique ID, and finally join the result back onto the main table.

Creating a unique ID:

import org.apache.spark.sql.functions._

ds = ds.withColumn("id", monotonically_increasing_id())

Then create the mapping table that contains id and your zip codes:

// Explode the zip array so each (id, zip) pair is one row, look up the
// coordinates, then collect them back into an array per id.
val map = ds
  .withColumn("_2", explode(col("_2")))
  .join(ds2, Seq("_2"), "left")
  .groupBy("id").agg(collect_set(col("_3")))
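
One caveat: collect_set deduplicates and does not guarantee element order, while the question asks for coordinates in the same order as the fifth column. If that matters, a posexplode-based variant (a sketch under the same column-name assumptions) keeps each element's position and sorts on it before collecting:

// posexplode emits (pos, element) pairs, so the original array order survives
// the shuffle; sort_array orders the structs by pos, then we project out _3.
val orderedMap = ds
  .select(col("id"), posexplode(col("_2")).as(Seq("pos", "_2")))
  .join(ds2, Seq("_2"), "left")
  .groupBy("id")
  .agg(sort_array(collect_list(struct(col("pos"), col("_3")))).as("tmp"))
  .select(col("id"), col("tmp._3").as("_6"))

orderedMap can then be joined back on id exactly like map in the next step.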

Finally join back on the main table:

// Re-attach the collected coordinates to the original rows
ds = ds.join(map, Seq("id"))

Final output:

+---+--------------+----------------------------------+
|_1 |_2            |collect_set(_3)                   |
+---+--------------+----------------------------------+
|15 |[98765]       |[{54.12, -80.53}]                 |
|16 |[97008, 98765]|[{54.12, -80.53}, {44.12, -75.11}]|
|15 |[97008]       |[{44.12, -75.11}]                 |
|15 |[97008]       |[{44.12, -75.11}]                 |
+---+--------------+----------------------------------+
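
If the exact _6 column name from the question is wanted, the aggregate column can be renamed afterwards (assuming Spark's default name for the aggregate, as shown above):

ds = ds.withColumnRenamed("collect_set(_3)", "_6")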

Good luck!
