PySpark - Cumulative sum with limits


I have a dataframe as follows:

+-------+----------+-----+
|user_id|      date|valor|
+-------+----------+-----+
|      1|2022-01-01|    0|
|      1|2022-01-02|    0|
|      1|2022-01-03|    1|
|      1|2022-01-04|    1|
|      1|2022-01-05|    1|
|      1|2022-01-06|    0|
|      1|2022-01-07|    0|
|      1|2022-01-08|    0|
|      1|2022-01-09|    1|
|      1|2022-01-10|    1|
|      1|2022-01-11|    1|
|      1|2022-01-12|    0|
|      1|2022-01-13|    0|
|      1|2022-01-14|   -1|
|      1|2022-01-15|   -1|
|      1|2022-01-16|   -1|
|      1|2022-01-17|   -1|
|      1|2022-01-18|   -1|
|      1|2022-01-19|   -1|
|      1|2022-01-20|    0|
+-------+----------+-----+

The goal is to calculate a score for each user_id using valor as the base. The score starts at 3 and increases or decreases by 1 following the valor column. The main problem is that the score can't go below 1 or above 5, so the sum must always stay in that range while still carrying the last value forward, so that the next row is computed correctly. This is what I expect:

+-------+----------+-----+-----+
|user_id|      date|valor|score|
+-------+----------+-----+-----+
|      1|2022-01-01|    0|    3|
|      1|2022-01-02|    0|    3|
|      1|2022-01-03|    1|    4|
|      1|2022-01-04|    1|    5|
|      1|2022-01-05|    1|    5|
|      1|2022-01-06|    0|    5|
|      1|2022-01-07|    0|    5|
|      1|2022-01-08|    0|    5|
|      1|2022-01-09|    1|    5|
|      1|2022-01-10|   -1|    4|
|      1|2022-01-11|   -1|    3|
|      1|2022-01-12|    0|    3|
|      1|2022-01-13|    0|    3|
|      1|2022-01-14|   -1|    2|
|      1|2022-01-15|   -1|    1|
|      1|2022-01-16|    1|    2|
|      1|2022-01-17|   -1|    1|
|      1|2022-01-18|   -1|    1|
|      1|2022-01-19|    1|    2|
|      1|2022-01-20|    0|    2|
+-------+----------+-----+-----+
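
For reference, the rule can be written in plain Python without Spark. This is only a minimal sketch of the intended logic: the name `clamped_scores` and its parameters `start`, `low` and `high` are made up for illustration, and the valor list is copied from the expected table above.

```python
from itertools import accumulate

def clamped_scores(valores, start=3, low=1, high=5):
    """Running score: add each valor to the previous score, clamped to [low, high]."""
    step = lambda prev, v: max(low, min(high, prev + v))
    # `initial=start` seeds the running value; [1:] drops that seed from the output.
    return list(accumulate(valores, step, initial=start))[1:]

# valor column of the expected table above
valor = [0, 0, 1, 1, 1, 0, 0, 0, 1, -1, -1, 0, 0, -1, -1, 1, -1, -1, 1, 0]
scores = clamped_scores(valor)
print(scores)
# [3, 3, 4, 5, 5, 5, 5, 5, 5, 4, 3, 3, 3, 2, 1, 2, 1, 1, 2, 2]
```

The important detail is that the clamped value, not the raw sum, is what gets carried to the next row.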

So far I've used a window to rank the valor column, so I can keep track of how many consecutive increases or decreases there are, and removed from valor the sequences longer than 4 (column valor_). But I don't know how to keep the cumulative sum of valor_ in the range 1 to 5:

+-------+----------+----+-----+------+
|user_id|      date|rank|valor|valor_|
+-------+----------+----+-----+------+
|      1|2022-01-01|   0|    0|     0|
|      1|2022-01-02|   0|    0|     0|
|      1|2022-01-03|   1|    1|     1|
|      1|2022-01-04|   2|    1|     1|
|      1|2022-01-05|   3|    1|     1|
|      1|2022-01-06|   0|    0|     0|
|      1|2022-01-07|   0|    0|     0|
|      1|2022-01-08|   0|    0|     0|
|      1|2022-01-09|   1|    1|     1|
|      1|2022-01-10|   2|    1|     1|
|      1|2022-01-11|   3|    1|     1|
|      1|2022-01-12|   0|    0|     0|
|      1|2022-01-13|   0|    0|     0|
|      1|2022-01-14|   1|   -1|    -1|
|      1|2022-01-15|   2|   -1|    -1|
|      1|2022-01-16|   3|   -1|    -1|
|      1|2022-01-17|   4|   -1|    -1|
|      1|2022-01-18|   5|   -1|     0|
|      1|2022-01-19|   6|   -1|     0|

As you can see, the result here is not what I expected:

+-------+----------+----+-----+------+-----+
|user_id|      date|rank|valor|valor_|score|
+-------+----------+----+-----+------+-----+
|      1|2022-01-01|   0|    0|     0|    3|
|      1|2022-01-02|   0|    0|     0|    3|
|      1|2022-01-03|   1|    1|     1|    4|
|      1|2022-01-04|   2|    1|     1|    5|
|      1|2022-01-05|   3|    1|     1|    6|
|      1|2022-01-06|   0|    0|     0|    6|
|      1|2022-01-07|   0|    0|     0|    6|
|      1|2022-01-08|   0|    0|     0|    6|
|      1|2022-01-09|   1|    1|     1|    7|
|      1|2022-01-10|   2|    1|     1|    8|
|      1|2022-01-11|   3|    1|     1|    9|
|      1|2022-01-12|   0|    0|     0|    9|
|      1|2022-01-13|   0|    0|     0|    9|
|      1|2022-01-14|   1|   -1|    -1|    8|
|      1|2022-01-15|   2|   -1|    -1|    7|
|      1|2022-01-16|   3|   -1|    -1|    6|
|      1|2022-01-17|   4|   -1|    -1|    5|
|      1|2022-01-18|   5|   -1|     0|    5|
|      1|2022-01-19|   6|   -1|     0|    5|
|      1|2022-01-20|   0|    0|     0|    5|
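
The underlying issue can be reproduced in plain Python. The sketch below is an illustration, not the window code from the attempt above: it uses the raw valor column of the input dataframe (not valor_) and compares clipping a plain cumulative sum afterwards with clamping at every step. The variable names are made up for the example.

```python
from itertools import accumulate

# valor column of the input dataframe in the question
valor = [0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, -1, -1, -1, -1, -1, -1, 0]

# Clip AFTER summing: the hidden running total still drifts out of range
# (up to 9 here), so the first decreases only bring it back toward 5.
plain = [3 + s for s in accumulate(valor)]
clipped_after = [max(1, min(5, s)) for s in plain]

# Clamp AT EVERY STEP: the clamped value itself is carried to the next row.
clamped_each_step = list(
    accumulate(valor, lambda prev, v: max(1, min(5, prev + v)), initial=3)
)[1:]

print(clipped_after[-4:])      # [5, 4, 3, 3]
print(clamped_each_step[-4:])  # [1, 1, 1, 1]
```

A window sum only sees the original values, so it cannot know that an earlier row was clamped. That is why the calculation needs some form of sequential state per user.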

CodePudding user response:


from pyspark.sql import functions as F
df = spark.createDataFrame(
    [(1, '2022-01-01',  0),
     (1, '2022-01-02',  0),
     (1, '2022-01-03',  1),
     (1, '2022-01-04',  1),
     (1, '2022-01-05',  1),
     (1, '2022-01-06',  0),
     (1, '2022-01-07',  0),
     (1, '2022-01-08',  0),
     (1, '2022-01-09',  1),
     (1, '2022-01-10',  1),
     (1, '2022-01-11',  1),
     (1, '2022-01-12',  0),
     (1, '2022-01-13',  0),
     (1, '2022-01-14', -1),
     (1, '2022-01-15', -1),
     (1, '2022-01-16', -1),
     (1, '2022-01-17', -1),
     (1, '2022-01-18', -1),
     (1, '2022-01-19', -1),
     (1, '2022-01-20',  0)],
    ['user_id', 'date', 'valor'])


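# Collect each user's rows into one array of [date, valor] strings, sorted by date
df = df.groupBy('user_id').agg(F.array_sort(F.collect_list(F.array('date', 'valor'))).alias('a'))
# Fold over the array, carrying the clamped running value in the last struct.
# F.aggregate / F.filter with Python lambdas need Spark 3.1+.
df = df.withColumn('a', F.filter(
    F.aggregate(
        'a',
        F.expr("array(struct('' as date, 0 as valor, 3 as cum))"),
        lambda acc, x: F.array_union(acc, F.array(F.struct(
            x[0].alias('date'),
            x[1].cast('int').alias('valor'),
            F.greatest(F.lit(1), F.least(F.lit(5), x[1].cast('int') + F.element_at(acc, -1)['cum'])).alias('cum'))))),
    lambda x: x['date'] != ''))  # drop the seed element that only carried the start value 3
df = df.selectExpr("user_id", "inline(a)")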

# +-------+----------+-----+---+
# |user_id|      date|valor|cum|
# +-------+----------+-----+---+
# |      1|2022-01-01|    0|  3|
# |      1|2022-01-02|    0|  3|
# |      1|2022-01-03|    1|  4|
# |      1|2022-01-04|    1|  5|
# |      1|2022-01-05|    1|  5|
# |      1|2022-01-06|    0|  5|
# |      1|2022-01-07|    0|  5|
# |      1|2022-01-08|    0|  5|
# |      1|2022-01-09|    1|  5|
# |      1|2022-01-10|    1|  5|
# |      1|2022-01-11|    1|  5|
# |      1|2022-01-12|    0|  5|
# |      1|2022-01-13|    0|  5|
# |      1|2022-01-14|   -1|  4|
# |      1|2022-01-15|   -1|  3|
# |      1|2022-01-16|   -1|  2|
# |      1|2022-01-17|   -1|  1|
# |      1|2022-01-18|   -1|  1|
# |      1|2022-01-19|   -1|  1|
# |      1|2022-01-20|    0|  1|
# +-------+----------+-----+---+
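
The fold used here (a seed element that carries the start value, a merge step that appends one clamped row at a time, and a final filter that drops the seed) can be mirrored in plain Python with `functools.reduce`. This is only a sketch to show the mechanics; the sample rows and the names `seed` and `merge` are invented for the illustration and are not part of the Spark code.

```python
from functools import reduce

# (date, valor) pairs as strings, similar to what an array of date and valor holds
rows = [("2022-01-03", "1"), ("2022-01-01", "1"), ("2022-01-02", "1"), ("2022-01-04", "-1")]

# Seed element: empty date, carries the starting value 3
seed = [{"date": "", "valor": 0, "cum": 3}]

def merge(acc, x):
    """Append one row whose cum is the previous cum plus valor, clamped to [1, 5]."""
    valor = int(x[1])
    cum = max(1, min(5, valor + acc[-1]["cum"]))
    return acc + [{"date": x[0], "valor": valor, "cum": cum}]

folded = reduce(merge, sorted(rows), seed)       # sort by date, then fold
result = [r for r in folded if r["date"] != ""]  # drop the seed element

print([r["cum"] for r in result])
# [4, 5, 5, 4]
```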

CodePudding user response:

tl;dr: a complex approach similar to this; consider it a last resort because of its complexity

A Python function can keep track of the previous cumulative sum value, and that function can be applied with flatMapValues() to process the data.

Consider the following input data

import random
import pandas as pd
from pyspark.sql import functions as func

data1_ls = [(1, k.strftime('%Y-%m-%d'), random.randint(-1, 1)) for k in pd.date_range(pd.to_datetime('2022-01-01'), pd.to_datetime('2022-01-20'))]
data2_ls = [(2, k.strftime('%Y-%m-%d'), random.randint(-1, 1)) for k in pd.date_range(pd.to_datetime('2022-04-01'), pd.to_datetime('2022-04-30'))]

data1_sdf = spark.sparkContext.parallelize(data1_ls).toDF(['user', 'dt', 'valor']). \
    withColumn('dt', func.col('dt').cast('date'))

data2_sdf = spark.sparkContext.parallelize(data2_ls).toDF(['user', 'dt', 'valor']). \
    withColumn('dt', func.col('dt').cast('date'))

data_sdf = data1_sdf.unionByName(data2_sdf)

# +----+----------+-----+
# |user|        dt|valor|
# +----+----------+-----+
# |   1|2022-01-01|    1|
# |   1|2022-01-02|   -1|
# |   1|2022-01-03|    0|
# |   1|2022-01-04|    1|
# |   1|2022-01-05|    0|
# +----+----------+-----+

We can write a Python function that computes the running sum and keeps track of the previous value. This function should be shipped to all executors for optimal resource usage.

def cumsum_in_range(groupedRows, initial_value=3):
    """Return each row as a list with a running sum of `valor`, clamped to [1, 5], appended."""
    res = []
    frstRec = True
    initVal = initial_value

    for row in groupedRows:
        if frstRec:
            # data starts from a static value
            frstRec = False
            cumsum = initVal + row.valor
        else:
            cumsum = prev_cumsum + row.valor

        # keep the running sum inside the allowed range
        if cumsum > 5:
            cumsum = 5
        elif cumsum < 1:
            cumsum = 1
        prev_cumsum = cumsum  # keeping track of the latest sum for next iteration

        res.append([item for item in row] + [cumsum])
    return res

To apply the function, we'll use groupBy() and flatMapValues(). The groupBy() groups the rows by the key provided (here, the user). The cumulative sum also needs each group's rows ordered by the date field, so sorted() is used with the date field passed as the key.

# run the python function and keep only the resulting values
res_vals = data_sdf.rdd. \
    groupBy(lambda gk: gk.user). \
    flatMapValues(lambda r: cumsum_in_range(sorted(r, key=lambda ok: ok.dt))). \
    values()

# create schema for the new column in previous dataframe
data_schema = data_sdf.withColumn('dropme', func.lit(None).cast('int')). \
    drop('dropme'). \
    schema. \
    add('cumsum', 'integer')

# create a dataframe with the new values
res_sdf = spark.createDataFrame(res_vals, data_schema)
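
The group, sort and flatten pattern above can be tried locally without a Spark session. The following is a rough plain-Python analogue, not the RDD API itself: `itertools.groupby` stands in for groupBy(), the nested comprehension stands in for flatMapValues() followed by values(), and `Row`, `clamp_running` and the sample data are hypothetical names made up for the sketch.

```python
from collections import namedtuple
from itertools import groupby

Row = namedtuple("Row", ["user", "dt", "valor"])

def clamp_running(rows, start=3, low=1, high=5):
    """Append a running sum of valor, clamped to [low, high], to every row."""
    out, prev = [], start
    for row in rows:
        prev = max(low, min(high, prev + row.valor))
        out.append(list(row) + [prev])
    return out

# Unordered sample rows for two users
data = [
    Row(2, "2022-04-02", -1), Row(1, "2022-01-02", 1), Row(1, "2022-01-01", 1),
    Row(2, "2022-04-01", -1), Row(1, "2022-01-03", 1), Row(2, "2022-04-03", -1),
]

# groupby() needs its input sorted by the grouping key
by_user = sorted(data, key=lambda r: r.user)
result = [
    new_row
    for _, grp in groupby(by_user, key=lambda r: r.user)      # one group per user
    for new_row in clamp_running(sorted(grp, key=lambda r: r.dt))  # order by date, then flatten
]

for r in result:
    print(r)
```

Each user starts again from 3, because the running value lives inside the per-group function call.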

The res_sdf dataframe will have the cumulative sum column created for each user, based on the python function defined above.

res_sdf. \
    filter(func.col('user') == 1). \
    orderBy(['user', 'dt']). \
    show()

# +----+----------+-----+------+
# |user|        dt|valor|cumsum|
# +----+----------+-----+------+
# |   1|2022-01-01|    1|     4|
# |   1|2022-01-02|   -1|     3|
# |   1|2022-01-03|    0|     3|
# |   1|2022-01-04|    1|     4|
# |   1|2022-01-05|    0|     4|
# |   1|2022-01-06|    1|     5|
# |   1|2022-01-07|    0|     5|
# |   1|2022-01-08|    1|     5|
# |   1|2022-01-09|    0|     5|
# |   1|2022-01-10|   -1|     4|
# |   1|2022-01-11|   -1|     3|
# |   1|2022-01-12|   -1|     2|
# |   1|2022-01-13|    1|     3|
# |   1|2022-01-14|   -1|     2|
# |   1|2022-01-15|    1|     3|
# |   1|2022-01-16|   -1|     2|
# |   1|2022-01-17|    0|     2|
# |   1|2022-01-18|    1|     3|
# |   1|2022-01-19|    0|     3|
# |   1|2022-01-20|   -1|     2|
# +----+----------+-----+------+

res_sdf. \
    filter(func.col('user') == 2). \
    orderBy(['user', 'dt']). \
    show()

# +----+----------+-----+------+
# |user|        dt|valor|cumsum|
# +----+----------+-----+------+
# |   2|2022-04-01|   -1|     2|
# |   2|2022-04-02|    0|     2|
# |   2|2022-04-03|    1|     3|
# |   2|2022-04-04|   -1|     2|
# |   2|2022-04-05|    1|     3|
# |   2|2022-04-06|    0|     3|
# |   2|2022-04-07|    1|     4|
# |   2|2022-04-08|   -1|     3|
# |   2|2022-04-09|    0|     3|
# |   2|2022-04-10|    0|     3|
# |   2|2022-04-11|   -1|     2|
# |   2|2022-04-12|    1|     3|
# |   2|2022-04-13|    0|     3|
# |   2|2022-04-14|    0|     3|
# |   2|2022-04-15|    1|     4|
# |   2|2022-04-16|   -1|     3|
# |   2|2022-04-17|    0|     3|
# |   2|2022-04-18|    0|     3|
# |   2|2022-04-19|    1|     4|
# |   2|2022-04-20|    1|     5|
# +----+----------+-----+------+
# only showing top 20 rows