Pyspark window function with conditions to round number of travelers


I am using Pyspark and I would like to create a function which performs the following operation:

Given data describing the transactions of train users:

+----+----------+--------+-----+
|date|total_trav|num_trav|order|
+----+----------+--------+-----+
|   1|         9|     2.7|    1|
|   1|         9|     1.3|    2|
|   1|         9|     1.3|    3|
|   1|         9|     1.3|    4|
|   1|         9|     1.2|    5|
|   1|         9|     1.1|    6|
|   2|         9|     2.7|    1|
|   2|         9|     1.3|    2|
|   2|         9|     1.3|    3|
|   2|         9|     1.3|    4|
|   2|         9|     1.2|    5|
|   2|         9|     1.1|    6|
+----+----------+--------+-----+

I would like to round the numbers of the num_trav column based on the order given in the order column, while grouping by date to obtain the trav_res column. The logic behind it would be something like:

  • We group the data by date.
  • Within each group (date=1 and date=2), every value must be rounded up (ceil(num_trav)), regardless of its value. However, each group has a maximum number of travelers (total_trav), which here is 9 for both groups.
  • This is where the order column comes in: you round in the order given by that column, checking at each step how many travelers are left in the group's budget (see the plain-Python sketch after this list).
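
To make the intended logic concrete, here is a minimal plain-Python sketch of the greedy allocation for a single group (the name allocate_group is illustrative, not part of the original question):

import math

def allocate_group(num_travs, total_trav):
    # Greedy allocation: round each value up, in the given order, and
    # hand out at most the remaining budget; once the budget is
    # exhausted, later rows get 0.
    left = total_trav
    result = []
    for x in num_travs:  # values already sorted by the order column
        take = min(math.ceil(x), left)
        result.append(take)
        left -= take
    return result

print(allocate_group([2.7, 1.3, 1.3, 1.3, 1.2, 1.1], 9))
# [3, 2, 2, 2, 0, 0]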

For example, let's consider this result dataframe and see how the trav_res column is formed:

+----+----------+--------+-----+--------+
|date|total_trav|num_trav|order|trav_res|
+----+----------+--------+-----+--------+
|   1|         9|     2.7|    1|       3|
|   1|         9|     1.3|    2|       2|
|   1|         9|     1.3|    3|       2|
|   1|         9|     1.3|    4|       2|
|   1|         9|     1.2|    5|       0|
|   1|         9|     1.1|    6|       0|
|   2|         9|     2.7|    1|       3|
|   2|         9|     1.3|    2|       2|
|   2|         9|     1.3|    3|       2|
|   2|         9|     1.3|    4|       2|
|   2|         9|     1.2|    5|       0|
|   2|         9|     1.1|    6|       0|
+----+----------+--------+-----+--------+

In the example above, grouping by date gives 2 groups, each with a maximum of 9 travelers (the total_trav column). For group 1, you start by rounding num_trav=2.7 up to 3 (trav_res column), then num_trav=1.3 up to 2, then the next 1.3 up to 2, then the last 1.3 up to 2, following the given order. At that point 3+2+2+2=9 travelers have been assigned and the budget is exhausted, so the remaining rows get trav_res=0 regardless of their num_trav values.

I have tried some udf functions, but they don't seem to do the job.

CodePudding user response:

You can first apply F.ceil to all rows of num_trav, then create a cumulative-sum column based on the ceiling values, and then set the ceiling value to zero wherever the cumulative sum exceeds total_trav, as in the code below:

# create the dataframe
import pyspark.sql.functions as F
from pyspark.sql import SparkSession, Window

spark = SparkSession.builder.getOrCreate()

data = [(1, 9, 2.7, 1),
        (1, 9, 1.3, 2),
        (1, 9, 1.3, 3),
        (1, 9, 1.3, 4),
        (1, 9, 1.2, 5),
        (1, 9, 1.1, 6),
        (2, 9, 2.7, 1),
        (2, 9, 1.3, 2),
        (2, 9, 1.3, 3),
        (2, 9, 1.3, 4),
        (2, 9, 1.2, 5),
        (2, 9, 1.1, 6)]

df = spark.createDataFrame(data, schema=["date", "total_trav", "num_trav", "order"])

# create ceiling column
df = df.withColumn("num_trav_ceil", F.ceil("num_trav"))

# create cumulative sum column
w = Window.partitionBy("date").orderBy("order")
df = df.withColumn("num_trav_ceil_cumsum", F.sum("num_trav_ceil").over(w))

# impose 0 in trav_res when cumsum exceeds total_trav
df = (df
  .withColumn("trav_res", 
               F.when(F.col("num_trav_ceil_cumsum")<=F.col("total_trav"), 
               F.col("num_trav_ceil"))
               .otherwise(0))
  .select("date", "total_trav", "num_trav", "order", "trav_res"))
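
On the example data this reproduces the trav_res column from the question; a quick check (assuming a live SparkSession), with the output shown as comments:

df.orderBy("date", "order").show()
# +----+----------+--------+-----+--------+
# |date|total_trav|num_trav|order|trav_res|
# +----+----------+--------+-----+--------+
# |   1|         9|     2.7|    1|       3|
# |   1|         9|     1.3|    2|       2|
# |   1|         9|     1.3|    3|       2|
# |   1|         9|     1.3|    4|       2|
# |   1|         9|     1.2|    5|       0|
# |   1|         9|     1.1|    6|       0|
# ... (the date=2 rows are identical)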

CodePudding user response:

This solution is based on @AnnaK.'s answer, with a small addition so that the total number of travelers (total_trav) is used exactly: not more, not less.

# create ceiling column (df is the dataframe built above)
df = df.withColumn("num_trav_ceil", F.ceil("num_trav"))

# create cumulative sum column
w = Window.partitionBy("date").orderBy("order")
df = df.withColumn("num_trav_ceil_cumsum", F.sum("num_trav_ceil").over(w))

# keep the ceiling while the running total stays within total_trav;
# when the cumulative sum overshoots total_trav by exactly 1, assign
# the single remaining traveler (here every overshooting row ceils
# to 2, so exactly one traveler is left); otherwise 0
df = (df
  .withColumn("trav_res",
              F.when(F.col("num_trav_ceil_cumsum") <= F.col("total_trav"),
                     F.col("num_trav_ceil"))
              .when((F.col("num_trav_ceil_cumsum") - F.col("total_trav") > 0)
                    & (F.col("num_trav_ceil_cumsum") - F.col("total_trav") <= 1),
                    1)
              .otherwise(0))
  .select("date", "total_trav", "num_trav", "order", "trav_res"))
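
Note that assigning a flat 1 works for this data, where the overshooting row always ceils to 2. If rows with larger ceiling values can overshoot, a more general variant (a sketch, not from the original answer) replaces the trav_res step above and assigns whatever capacity is left in the group:

# remaining capacity before this row: total_trav minus the cumulative
# sum of all earlier rows in the group
remaining = F.col("total_trav") - (F.col("num_trav_ceil_cumsum") - F.col("num_trav_ceil"))

df = (df
  .withColumn("trav_res",
              F.when(F.col("num_trav_ceil_cumsum") <= F.col("total_trav"),
                     F.col("num_trav_ceil"))    # whole ceiling fits in the budget
              .when(remaining > 0, remaining)   # partial fit: hand out what is left
              .otherwise(F.lit(0)))             # budget already exhausted
  .select("date", "total_trav", "num_trav", "order", "trav_res"))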