Longest Run without being in the UK


I have the following Spark DataFrame:

val inputDf = List(
    ("1", "1", "UK", "Spain", "2022-01-01"),
    ("1", "2", "Spain", "Germany", "2022-01-02"),
    ("1", "3", "Germany", "China", "2022-01-03"),
    ("1", "4", "China", "France", "2022-01-04"),
    ("1", "5", "France", "Spain", "2022-01-05"),
    ("1", "6", "Spain", "Italy", "2022-01-09"),
    ("1", "7", "Italy", "UK", "2022-01-14"),
    ("1", "8", "UK", "USA", "2022-01-15"),
    ("1", "9", "USA", "Canada", "2022-01-16"),
    ("1", "10", "Canada", "UK", "2022-01-17"),
    ("2", "16", "USA", "Finland", "2022-01-11"),
    ("2", "17", "Finland", "Russia", "2022-01-12"),
    ("2", "18", "Russia", "Turkey", "2022-01-13"),
    ("2", "19", "Turkey", "Japan", "2022-01-14"),
    ("2", "20", "Japan", "UK", "2022-01-15"),
  ).toDF("passengerId", "flightId", "from", "to", "date")

I would like to get, for each passenger, the longest run without being in the UK. For example, passenger 1's itinerary was UK>Spain>Germany>China>France>Spain>Italy>UK>USA>Canada>UK>Finland>Russia>Turkey>Japan>Spain>Germany>China>France>Spain>Italy>UK>USA>Canada>UK, so the longest run would be 10.
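To make the definition concrete, here is a minimal plain-Scala sketch (not Spark; `longestRunOutsideUK` is a hypothetical name). It walks an ordered itinerary, counts consecutive non-UK stops and resets the count at every UK stop. Applied to the itinerary above it returns 10.

```scala
object LongestRun {
  // Hypothetical helper: length of the longest stretch of consecutive non-UK stops
  def longestRunOutsideUK(itinerary: Seq[String]): Int = {
    // Accumulator: (length of the current run, longest run seen so far);
    // a "UK" stop resets the current run to 0
    val (_, longest) = itinerary.foldLeft((0, 0)) { case ((current, best), country) =>
      if (country == "UK") (0, best)
      else (current + 1, math.max(best, current + 1))
    }
    longest
  }

  def main(args: Array[String]): Unit = {
    val passenger1 =
      "UK>Spain>Germany>China>France>Spain>Italy>UK>USA>Canada>UK>Finland>Russia>Turkey>Japan>Spain>Germany>China>France>Spain>Italy>UK>USA>Canada>UK"
    println(longestRunOutsideUK(passenger1.split(">").toSeq)) // 10
  }
}
```

The fold only needs the stops in travel order. Building that order from the flight rows is the part Spark has to do.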

I first merged the `from` and `to` columns using the following code:

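  import org.apache.spark.sql.functions.{col, collect_list, concat}

  val passengerWithCountries = inputDf.groupBy("passengerId")
      // concat joins the two arrays built from columns "from" and "to"
      // collect_list gathers all values from the given column into an array
      .agg(concat(collect_list(col("from")), collect_list(col("to"))).as("countries"))

  passengerWithCountries.show(false)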


 ----------- ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ 
|passengerId|countries                                                                                                                                                                                                   |
 ----------- ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ 
|1          |[UK, Spain, Germany, China, France, Spain, Italy, UK, USA, Canada, UK, Finland, Russia, Turkey, Japan, Spain, Germany, China, France, Spain, Italy, UK, USA, Canada, UK, Finland, Russia, Turkey, Japan, UK]|
|2          |[USA, Finland, Russia, Turkey, Japan, Finland, Russia, Turkey, Japan, UK]                                                                                                                                   |
 ----------- ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ 

This is what I have tried. However, it does not work, because the values of my column are Array[String] and not String.

.withColumn("countries_new", explode(split(Symbol("countries"), "UK,")))
.withColumn("journey_outside_UK", size(split(Symbol("countries"), ",")))
.agg(max(Symbol("journey_outside_UK")) as "longest_run").show()

I am looking for the following output:

+-----------+-----------+
|passengerId|longest_run|
+-----------+-----------+
|1          |10         |
|2          |5          |
+-----------+-----------+

Please let me know if you have a solution.

Answer:

// Added some edge cases:
//   passengerId=3: just one flight from UK to non-UK, longest run must be 1
//   passengerId=4: just one flight from non-UK to UK, longest run must be 1
//   passengerId=5: just one flight from UK to UK, longest run must be 0
//   passengerId=6: one flight from UK to UK, followed by UK to non-UK, longest run must be 1
val inputDf = List(
    ("1", "1", "UK", "Spain", "2022-01-01"),
    ("1", "2", "Spain", "Germany", "2022-01-02"),
    ("1", "3", "Germany", "China", "2022-01-03"),
    ("1", "4", "China", "France", "2022-01-04"),
    ("1", "5", "France", "Spain", "2022-01-05"),
    ("1", "6", "Spain", "Italy", "2022-01-09"),
    ("1", "7", "Italy", "UK", "2022-01-14"),
    ("1", "8", "UK", "USA", "2022-01-15"),
    ("1", "9", "USA", "Canada", "2022-01-16"),
    ("1", "10", "Canada", "UK", "2022-01-17"),
    ("2", "16", "USA", "Finland", "2022-01-11"),
    ("2", "17", "Finland", "Russia", "2022-01-12"),
    ("2", "18", "Russia", "Turkey", "2022-01-13"),
    ("2", "19", "Turkey", "Japan", "2022-01-14"),
    ("2", "20", "Japan", "UK", "2022-01-15"),
    ("3", "21", "UK", "Spain", "2022-01-01"),
    ("4", "22", "Spain", "UK", "2022-01-01"),
    ("5", "23", "UK", "UK", "2022-01-01"),
    ("6", "24", "UK", "UK", "2022-01-01"),
    ("6", "25", "UK", "Spain", "2022-01-02"),
    ("7", "25", "Spain", "Germany", "2022-01-02"),
  ).toDF("passengerId", "flightId", "from", "to", "date")

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{expr, max, sum}

// Declare window for analytic functions
val w = Window.partitionBy("passengerId").orderBy("date")

// Use an analytic function to split each passenger's rows into UK-...-UK itineraries
// (newUK = number of departures from the UK so far)
val ukArrivals = inputDf.withColumn("newUK", sum(expr("case when from = 'UK' then 1 else 0 end")).over(w))
ukArrivals.show()
+-----------+--------+-------+-------+----------+-----+
|passengerId|flightId|   from|     to|      date|newUK|
+-----------+--------+-------+-------+----------+-----+
|          1|       1|     UK|  Spain|2022-01-01|    1|
|          1|       2|  Spain|Germany|2022-01-02|    1|
|          1|       3|Germany|  China|2022-01-03|    1|
|          1|       4|  China| France|2022-01-04|    1|
|          1|       5| France|  Spain|2022-01-05|    1|
|          1|       6|  Spain|  Italy|2022-01-09|    1|
|          1|       7|  Italy|     UK|2022-01-14|    1|
|          1|       8|     UK|    USA|2022-01-15|    2|
|          1|       9|    USA| Canada|2022-01-16|    2|
|          1|      10| Canada|     UK|2022-01-17|    2|
|          2|      16|    USA|Finland|2022-01-11|    0|
|          2|      17|Finland| Russia|2022-01-12|    0|
|          2|      18| Russia| Turkey|2022-01-13|    0|
|          2|      19| Turkey|  Japan|2022-01-14|    0|
|          2|      20|  Japan|     UK|2022-01-15|    0|
|          3|      21|     UK|  Spain|2022-01-01|    1|
|          4|      22|  Spain|     UK|2022-01-01|    0|
|          5|      23|     UK|     UK|2022-01-01|    1|
|          6|      24|     UK|     UK|2022-01-01|    1|
|          6|      25|     UK|  Spain|2022-01-02|    2|
+-----------+--------+-------+-------+----------+-----+

// Calculate longest runs outside UK
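val runs = (ukArrivals
    .groupBy("passengerId", "newUK") // for each UK-...-UK itinerary
    .agg((
        sum(expr("""
            case
                when 'UK' not in (from,to) then 1 -- count all non-UK countries, except for the first one
                when from = to then -1            -- special case for UK-UK flights
                else 0                            -- don't count flights from/to UK
            end""")) + 1   // count the first non-UK country
    ).as("longest_run_outside_UK"))
    .groupBy("passengerId")
    .agg(max("longest_run_outside_UK").as("longest_run_outside_UK"))
)
runs.show()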
+-----------+----------------------+
|passengerId|longest_run_outside_UK|
+-----------+----------------------+
|          1|                     6|
|          2|                     5|
|          3|                     1|
|          4|                     1|
|          5|                     0|
|          6|                     1|
+-----------+----------------------+
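To check the logic of the Spark version without a cluster, here is a minimal plain-Scala sketch. It is a re-implementation of the same rule, not the answer's code. The `Flight` case class and the `weight`, `longestRuns` and `sample` names are hypothetical.

The sketch mirrors both steps:
1. A running count of UK departures serves as the segment id, like `newUK`.
2. The CASE weights are summed per segment, plus one.

One caveat about the Spark version: a window with `orderBy` and no explicit frame uses a RANGE frame from the start of the partition to the current row. Two flights of the same passenger with the same `date` would therefore get the same `newUK` value. The sketch instead sorts by date and counts strictly row by row.

```scala
object LongestRunPerSegment {
  // Hypothetical record holding the columns the answer's logic actually uses
  final case class Flight(passengerId: String, date: String, from: String, to: String)

  // Same rule as the CASE expression: +1 for a flight with both ends outside the UK,
  // -1 for a UK -> UK flight, 0 for a flight that touches the UK at one end
  def weight(f: Flight): Int =
    if (f.from != "UK" && f.to != "UK") 1
    else if (f.from == f.to) -1
    else 0

  // For each passenger: number flights by a running count of UK departures (like newUK),
  // sum the weights per segment, add 1 for the first non-UK country, keep the maximum
  def longestRuns(flights: Seq[Flight]): Map[String, Int] =
    flights.groupBy(_.passengerId).map { case (passengerId, legs) =>
      val ordered = legs.sortBy(_.date)
      val segmentIds = ordered.scanLeft(0)((acc, f) => if (f.from == "UK") acc + 1 else acc).tail
      val perSegment = ordered.zip(segmentIds).groupBy(_._2).values.map { segment =>
        segment.map { case (f, _) => weight(f) }.sum + 1
      }
      passengerId -> perSegment.max
    }

  // Passengers 1 to 6 from the test data above (flightId is not needed here)
  val sample: Seq[Flight] = Seq(
    Flight("1", "2022-01-01", "UK", "Spain"), Flight("1", "2022-01-02", "Spain", "Germany"),
    Flight("1", "2022-01-03", "Germany", "China"), Flight("1", "2022-01-04", "China", "France"),
    Flight("1", "2022-01-05", "France", "Spain"), Flight("1", "2022-01-09", "Spain", "Italy"),
    Flight("1", "2022-01-14", "Italy", "UK"), Flight("1", "2022-01-15", "UK", "USA"),
    Flight("1", "2022-01-16", "USA", "Canada"), Flight("1", "2022-01-17", "Canada", "UK"),
    Flight("2", "2022-01-11", "USA", "Finland"), Flight("2", "2022-01-12", "Finland", "Russia"),
    Flight("2", "2022-01-13", "Russia", "Turkey"), Flight("2", "2022-01-14", "Turkey", "Japan"),
    Flight("2", "2022-01-15", "Japan", "UK"),
    Flight("3", "2022-01-01", "UK", "Spain"),
    Flight("4", "2022-01-01", "Spain", "UK"),
    Flight("5", "2022-01-01", "UK", "UK"),
    Flight("6", "2022-01-01", "UK", "UK"), Flight("6", "2022-01-02", "UK", "Spain")
  )

  def main(args: Array[String]): Unit =
    longestRuns(sample).toSeq.sortBy(_._1).foreach { case (id, run) => println(s"$id -> $run") }
}
```

On this sample it returns 6, 5, 1, 1, 0 and 1 for passengers 1 to 6, matching the `runs` output above. The "+1 per segment" trick works because a segment of k consecutive non-UK countries contains k - 1 flights with both ends outside the UK.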
