I have the following Spark DataFrame:
val inputDf = List(
("1", "1", "UK", "Spain", "2022-01-01"),
("1", "2", "Spain", "Germany", "2022-01-02"),
("1", "3", "Germany", "China", "2022-01-03"),
("1", "4", "China", "France", "2022-01-04"),
("1", "5", "France", "Spain", "2022-01-05"),
("1", "6", "Spain", "Italy", "2022-01-09"),
("1", "7", "Italy", "UK", "2022-01-14"),
("1", "8", "UK", "USA", "2022-01-15"),
("1", "9", "USA", "Canada", "2022-01-16"),
("1", "10", "Canada", "UK", "2022-01-17"),
("2", "16", "USA", "Finland", "2022-01-11"),
("2", "17", "Finland", "Russia", "2022-01-12"),
("2", "18", "Russia", "Turkey", "2022-01-13"),
("2", "19", "Turkey", "Japan", "2022-01-14"),
("2", "20", "Japan", "UK", "2022-01-15"),
).toDF("passengerId", "flightId", "from", "to", "date")
For each passenger, I would like to get the longest run of consecutive countries visited without being in the UK. For example, passenger 1's itinerary was UK>Spain>Germany>China>France>Spain>Italy>UK>USA>Canada>UK>Finland>Russia>Turkey>Japan>Spain>Germany>China>France>Spain>Italy>UK>USA>Canada>UK, so the longest run would be 10.
I first merged the "from" and "to" columns using the following code:
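The definition can be pinned down with a small plain-Scala sketch that needs no Spark. It assumes the itinerary is already a flat, time-ordered list of country names and treats "UK" as home; the function name longestRunOutside is hypothetical. It scans the list once, resetting a counter at every UK stop and keeping the largest value seen.

```scala
// Longest number of consecutive stops outside `home` in a time-ordered itinerary.
// Hypothetical helper, assuming the itinerary is already flattened and ordered.
def longestRunOutside(itinerary: Seq[String], home: String = "UK"): Int = {
  // Accumulator: (length of the current run, best run seen so far)
  val (_, best) = itinerary.foldLeft((0, 0)) { case ((current, bestSoFar), country) =>
    if (country == home) (0, bestSoFar) // back in the UK: the current run ends
    else {
      val next = current + 1            // one more country abroad
      (next, math.max(bestSoFar, next))
    }
  }
  best
}

// Passenger 1's itinerary as written in the question
val stated =
  "UK>Spain>Germany>China>France>Spain>Italy>UK>USA>Canada>UK>Finland>Russia>Turkey>Japan>Spain>Germany>China>France>Spain>Italy>UK>USA>Canada>UK"
    .split(">").toSeq
val longest = longestRunOutside(stated) // 10
```

The runs in that itinerary have lengths 6, 2, 10 and 2, so the maximum is 10.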
import org.apache.spark.sql.functions._

val passengerWithCountries = inputDf.groupBy("passengerId")
  .agg(
    // concat joins the arrays built from the "from" and "to" columns into one array
    concat(
      // collect_list gathers all values of the given column into an array
      collect_list(col("from")),
      collect_list(col("to"))
    ).name("countries")
  )
Output:
----------- ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|passengerId|countries |
----------- ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|1 |[UK, Spain, Germany, China, France, Spain, Italy, UK, USA, Canada, UK, Finland, Russia, Turkey, Japan, Spain, Germany, China, France, Spain, Italy, UK, USA, Canada, UK, Finland, Russia, Turkey, Japan, UK]|
|2 |[USA, Finland, Russia, Turkey, Japan, Finland, Russia, Turkey, Japan, UK] |
----------- ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
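Note that concat(collect_list(from), collect_list(to)) places every "to" value after every "from" value, so the resulting array is not the travel order; Spark also documents collect_list as non-deterministic in ordering after a shuffle. Assuming consecutive legs connect (each flight leaves from where the previous one landed), the travel order is the first leg's "from" followed by each leg's "to", sorted by date. A plain-Scala sketch of that idea; the Leg case class and the itineraries function are hypothetical stand-ins for the DataFrame rows:

```scala
// Hypothetical record mirroring the DataFrame columns.
case class Leg(passengerId: String, flightId: String, from: String, to: String, date: String)

// Travel order per passenger: the first departure, then every arrival, sorted by date.
// Assumes consecutive legs connect (each flight leaves from where the previous one landed).
def itineraries(legs: Seq[Leg]): Map[String, Seq[String]] =
  legs.groupBy(_.passengerId).map { case (id, ls) =>
    val ordered = ls.sortBy(_.date) // ISO dates sort chronologically as strings
    id -> (ordered.head.from +: ordered.map(_.to))
  }

// Passenger 2's legs from the question, deliberately shuffled
val legs = Seq(
  Leg("2", "18", "Russia", "Turkey", "2022-01-13"),
  Leg("2", "16", "USA", "Finland", "2022-01-11"),
  Leg("2", "20", "Japan", "UK", "2022-01-15"),
  Leg("2", "17", "Finland", "Russia", "2022-01-12"),
  Leg("2", "19", "Turkey", "Japan", "2022-01-14")
)
val trip = itineraries(legs)("2") // USA, Finland, Russia, Turkey, Japan, UK
```

Sorting by date before building the list is what makes the result independent of the order in which the rows arrive.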
The solution I have tried is the following. However, it does not work, because the values of my column are of type Array[String], not String, and split only accepts a string column.
passengerWithCountries
.withColumn("countries_new", explode(split(Symbol("countries"), "UK,")))
.withColumn("journey_outside_UK", size(split(Symbol("countries"), ",")))
.groupBy("passengerId")
.agg(max(Symbol("journey_outside_UK")) as "longest_run").show()
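Two things go wrong in this attempt: split expects a string column while countries is array<string>, and the second withColumn splits countries again instead of countries_new. One way to keep the string-splitting idea is to join the array into a string first; Spark 2.4+ provides array_join for that. The plain-Scala sketch below mirrors the idea on an in-memory list, with mkString standing in for array_join; the function name is hypothetical and the list is assumed to be in travel order.

```scala
// Plain-Scala mirror of "join the array into a string, then split on UK".
// mkString plays the role of Spark's array_join; the list is assumed to be in travel order.
def longestRunViaSplit(countries: Seq[String]): Int =
  countries
    .mkString(",")                                // e.g. "UK,Spain,Germany,UK,USA"
    .split("UK")                                  // pieces between UK stops
    .map(_.split(",").count(_.nonEmpty))          // number of countries in each piece
    .foldLeft(0)((best, n) => math.max(best, n))  // 0 when no country is left
```

Splitting on the substring "UK" would misfire if another country name contained it, so splitting on whole array elements is the safer choice in general.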
I am looking for the following output:
+-----------+-----------+
|passengerId|longest_run|
+-----------+-----------+
|1          |10         |
|2          |5          |
+-----------+-----------+
Please let me know if you have a solution.
Answer:
// Added some edge cases:
// passengerId=3: a single flight from the UK to a non-UK country, longest run must be 1
// passengerId=4: a single flight from a non-UK country to the UK, longest run must be 1
// passengerId=5: a single flight from the UK to the UK, longest run must be 0
// passengerId=6: a UK-to-UK flight followed by a UK-to-non-UK flight, longest run must be 1
// passengerId=7: a single flight that never touches the UK, longest run must be 2
val inputDf = List(
("1", "1", "UK", "Spain", "2022-01-01"),
("1", "2", "Spain", "Germany", "2022-01-02"),
("1", "3", "Germany", "China", "2022-01-03"),
("1", "4", "China", "France", "2022-01-04"),
("1", "5", "France", "Spain", "2022-01-05"),
("1", "6", "Spain", "Italy", "2022-01-09"),
("1", "7", "Italy", "UK", "2022-01-14"),
("1", "8", "UK", "USA", "2022-01-15"),
("1", "9", "USA", "Canada", "2022-01-16"),
("1", "10", "Canada", "UK", "2022-01-17"),
("2", "16", "USA", "Finland", "2022-01-11"),
("2", "17", "Finland", "Russia", "2022-01-12"),
("2", "18", "Russia", "Turkey", "2022-01-13"),
("2", "19", "Turkey", "Japan", "2022-01-14"),
("2", "20", "Japan", "UK", "2022-01-15"),
("3", "21", "UK", "Spain", "2022-01-01"),
("4", "22", "Spain", "UK", "2022-01-01"),
("5", "23", "UK", "UK", "2022-01-01"),
("6", "24", "UK", "UK", "2022-01-01"),
("6", "25", "UK", "Spain", "2022-01-02"),
("7", "25", "Spain", "Germany", "2022-01-02"),
).toDF("passengerId", "flightId", "from", "to", "date")
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

// Declare the window for the analytic functions
val w = Window.partitionBy("passengerId").orderBy("date")
// Use an analytic function to label each UK-...-UK itinerary: newUK counts the departures from the UK so far
val ukArrivals = inputDf.withColumn("newUK", sum(expr("case when `from` = 'UK' then 1 else 0 end")).over(w))
ukArrivals.show
+-----------+--------+-------+-------+----------+-----+
|passengerId|flightId|   from|     to|      date|newUK|
+-----------+--------+-------+-------+----------+-----+
|          1|       1|     UK|  Spain|2022-01-01|    1|
|          1|       2|  Spain|Germany|2022-01-02|    1|
|          1|       3|Germany|  China|2022-01-03|    1|
|          1|       4|  China| France|2022-01-04|    1|
|          1|       5| France|  Spain|2022-01-05|    1|
|          1|       6|  Spain|  Italy|2022-01-09|    1|
|          1|       7|  Italy|     UK|2022-01-14|    1|
|          1|       8|     UK|    USA|2022-01-15|    2|
|          1|       9|    USA| Canada|2022-01-16|    2|
|          1|      10| Canada|     UK|2022-01-17|    2|
|          2|      16|    USA|Finland|2022-01-11|    0|
|          2|      17|Finland| Russia|2022-01-12|    0|
|          2|      18| Russia| Turkey|2022-01-13|    0|
|          2|      19| Turkey|  Japan|2022-01-14|    0|
|          2|      20|  Japan|     UK|2022-01-15|    0|
|          3|      21|     UK|  Spain|2022-01-01|    1|
|          4|      22|  Spain|     UK|2022-01-01|    0|
|          5|      23|     UK|     UK|2022-01-01|    1|
|          6|      24|     UK|     UK|2022-01-01|    1|
|          6|      25|     UK|  Spain|2022-01-02|    2|
|          7|      25|  Spain|Germany|2022-01-02|    0|
+-----------+--------+-------+-------+----------+-----+
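The window sum is a running count: with orderBy and no explicit frame, Spark uses a frame from the start of the partition up to the current row, so each row's newUK is the number of UK departures so far (rows sharing the same date would share a value, because that default frame is range-based). A plain-Scala equivalent for one passenger's departure countries, assumed to be sorted by date already; segmentIds is a hypothetical name:

```scala
// Running count of UK departures, one value per leg: the plain-Scala analogue of
// sum(case when from = 'UK' then 1 else 0 end).over(w) for a single passenger.
// Assumes the departure countries are already sorted by date.
def segmentIds(departures: Seq[String]): Seq[Int] =
  departures
    .scanLeft(0)((count, from) => if (from == "UK") count + 1 else count)
    .tail // scanLeft also emits the seed 0; drop it to get one id per leg

// Departure countries of passenger 1 in the answer's data
val p1 = Seq("UK", "Spain", "Germany", "China", "France", "Spain", "Italy", "UK", "USA", "Canada")
val p1Ids = segmentIds(p1) // 1,1,1,1,1,1,1,2,2,2, matching the newUK column
```

Every leg between two UK departures shares an id, which is what the subsequent groupBy("passengerId", "newUK") relies on.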
// Calculate the longest runs outside the UK
val runs = (
  ukArrivals
    .groupBy("passengerId", "newUK") // one group per UK-...-UK itinerary
    .agg((
      sum(
        expr("""
          case
            when 'UK' not in (`from`, `to`) then 1 -- count every non-UK country except the first one
            when `from` = `to` then -1             -- special case for UK-to-UK flights
            else 0                                 -- don't count flights from/to the UK
          end""")
      ) + 1 // count the first non-UK country
    ).as("notUK"))
    .groupBy("passengerId")
    .agg(max("notUK").as("longest_run_outside_UK"))
)
runs.orderBy("passengerId").show
+-----------+----------------------+
|passengerId|longest_run_outside_UK|
+-----------+----------------------+
|          1|                     6|
|          2|                     5|
|          3|                     1|
|          4|                     1|
|          5|                     0|
|          6|                     1|
|          7|                     2|
+-----------+----------------------+
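To check these numbers without a Spark session, the whole computation can be replayed in plain Scala: a running count of UK departures per passenger (the window step), the same CASE scoring summed per segment plus 1, and the maximum per passenger. This is a sketch; Leg, legScore and longestRuns are hypothetical names, and dates are assumed to be ISO strings so that sorting them as text is chronological.

```scala
// Hypothetical leg record: only the columns the computation needs.
case class Leg(passengerId: String, from: String, to: String, date: String)

// Same scoring as the answer's CASE expression.
def legScore(l: Leg): Int =
  if (l.from != "UK" && l.to != "UK") 1 // both ends abroad: count it
  else if (l.from == l.to) -1           // UK-to-UK flight
  else 0                                // flight leaving or entering the UK

def longestRuns(legs: Seq[Leg]): Map[String, Int] =
  legs.groupBy(_.passengerId).map { case (id, ls) =>
    val ordered = ls.sortBy(_.date) // ISO dates sort chronologically as strings
    // Running count of UK departures = segment id (the window step)
    val ids = ordered.scanLeft(0)((n, l) => if (l.from == "UK") n + 1 else n).tail
    // Sum the scores per segment, then add 1 for the first non-UK country
    val perSegment = ordered.zip(ids).groupBy(_._2).values.map { seg =>
      seg.map { case (l, _) => legScore(l) }.sum + 1
    }
    id -> perSegment.max
  }

val sample = Seq(
  Leg("1", "UK", "Spain", "2022-01-01"), Leg("1", "Spain", "Germany", "2022-01-02"),
  Leg("1", "Germany", "China", "2022-01-03"), Leg("1", "China", "France", "2022-01-04"),
  Leg("1", "France", "Spain", "2022-01-05"), Leg("1", "Spain", "Italy", "2022-01-09"),
  Leg("1", "Italy", "UK", "2022-01-14"), Leg("1", "UK", "USA", "2022-01-15"),
  Leg("1", "USA", "Canada", "2022-01-16"), Leg("1", "Canada", "UK", "2022-01-17"),
  Leg("2", "USA", "Finland", "2022-01-11"), Leg("2", "Finland", "Russia", "2022-01-12"),
  Leg("2", "Russia", "Turkey", "2022-01-13"), Leg("2", "Turkey", "Japan", "2022-01-14"),
  Leg("2", "Japan", "UK", "2022-01-15"),
  Leg("3", "UK", "Spain", "2022-01-01"),
  Leg("4", "Spain", "UK", "2022-01-01"),
  Leg("5", "UK", "UK", "2022-01-01"),
  Leg("6", "UK", "UK", "2022-01-01"), Leg("6", "UK", "Spain", "2022-01-02"),
  Leg("7", "Spain", "Germany", "2022-01-02")
)
val result = longestRuns(sample)
```

For passengers 1 to 6 this reproduces the table above (6, 5, 1, 1, 0, 1), and passenger 7, whose only flight never touches the UK, gets 2.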