Create column that shows previous earned salary (groupedBy)


I'm trying to replicate pandas' .shift() function in PySpark. I'm pretty close, but I need one final touch to make everything work: the right form of grouping (a sketch of the pandas behaviour I'm after follows the tables below).

I have a dataframe like this one:

month  userid      amount  counterparty_iban
1      John Jones  2400    ENG12345
4      John Jones  1200    ENG12345
8      John Jones  2300    ENG12345
5      John Jones  3000    AM12345
9      John Jones  5000    AM12345
12     John Jones  6000    AM12345
1      Joe Jones   1200    AM12345
2      Joe Jones   2400    AM12345
3      Joe Jones   5000    AM12345

I need to change the dataframe above to the format of the table below:

month  userid      amount  counterparty_iban  previous_salary
1      John Jones  2400    ENG12345           null
4      John Jones  1200    ENG12345           2400
8      John Jones  2300    ENG12345           1200
5      John Jones  3000    AM12345            null
9      John Jones  5000    AM12345            3000
12     John Jones  6000    AM12345            5000
1      Joe Jones   1200    AM12345            null
2      Joe Jones   2400    AM12345            1200
3      Joe Jones   5000    AM12345            2400
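
For reference, this per-group shift is a one-liner in pandas, which is the behaviour I'm replicating. A minimal sketch, assuming the same data in a pandas DataFrame (the name pdf is illustrative):

import pandas as pd

# Same data as the tables above
pdf = pd.DataFrame({
    "month": [1, 4, 8, 5, 9, 12, 1, 2, 3],
    "userid": ["John Jones"] * 6 + ["Joe Jones"] * 3,
    "amount": [2400, 1200, 2300, 3000, 5000, 6000, 1200, 2400, 5000],
    "counterparty_iban": ["ENG12345"] * 3 + ["AM12345"] * 6,
})

# shift(1) within each (userid, counterparty_iban) group, in month order
pdf["previous_salary"] = (pdf.sort_values("month")
                             .groupby(["userid", "counterparty_iban"])["amount"]
                             .shift(1))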

Here is the code to create the input dataframe (note that the amount column is named exactoriginalamount in the code):

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

columns = ["month", "userid", "exactoriginalamount", "counterparty_iban"]
data = [("1", "John Jones", "2400", "ENG12345"),
        ("4", "John Jones", "1200", "ENG12345"),
        ("8", "John Jones", "2300", "ENG12345"),
        ("5", "John Jones", "3000", "AM12345"),
        ("9", "John Jones", "5000", "AM12345"),
        ("12", "John Jones", "6000", "AM12345"),
        ("1", "Joe Jones", "1200", "AM12345"),
        ("2", "Joe Jones", "2400", "AM12345"),
        ("3", "Joe Jones", "5000", "AM12345")]

df = spark.createDataFrame(data=data, schema=columns)

I've been experimenting with the following code:

from pyspark.sql import functions as F
from pyspark.sql.window import Window

w = Window.partitionBy().orderBy(F.col("userid"))
df = df.withColumn("previous_salary", F.lag("exactoriginalamount", 1).over(w))

However, I somehow need to group by "userid" and "counterparty_iban" so that the "previous_salary" column shows the right data.

CodePudding user response:

You need to describe the partitions properly. From your example output, you want partitions (windows) based on "userid" and "counterparty_iban"; the lag function then runs within each partition separately. In the script below I also cast("long") the "month" column, because your code creates it as a string, and ordering by a string column yields a different sort order than ordering by a number (e.g. "12" sorts before "2").

w = Window.partitionBy("userid", "counterparty_iban").orderBy(F.col("month").cast("long"))
df = df.withColumn("previous_salary", F.lag("exactoriginalamount").over(w))

df.show()
# +-----+----------+-------------------+-----------------+---------------+
# |month|    userid|exactoriginalamount|counterparty_iban|previous_salary|
# +-----+----------+-------------------+-----------------+---------------+
# |    1| Joe Jones|               1200|          AM12345|           null|
# |    2| Joe Jones|               2400|          AM12345|           1200|
# |    3| Joe Jones|               5000|          AM12345|           2400|
# |    5|John Jones|               3000|          AM12345|           null|
# |    9|John Jones|               5000|          AM12345|           3000|
# |   12|John Jones|               6000|          AM12345|           5000|
# |    1|John Jones|               2400|         ENG12345|           null|
# |    4|John Jones|               1200|         ENG12345|           2400|
# |    8|John Jones|               2300|         ENG12345|           1200|
# +-----+----------+-------------------+-----------------+---------------+
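
To see why the cast matters, compare the two sort orders on the "month" column (a quick check using the dataframe from the question):

df.select("month").orderBy("month").show()
# lexicographic (string) order: 1, 1, 12, 2, 3, 4, 5, 8, 9

df.select("month").orderBy(F.col("month").cast("long")).show()
# numeric order: 1, 1, 2, 3, 4, 5, 8, 9, 12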

CodePudding user response:

You can also do something like this, using max over a one-row sliding frame instead of lag:

from pyspark.sql import functions as F
from pyspark.sql.window import Window

custom_window = (Window.partitionBy("userid", "counterparty_iban")
                 .orderBy(F.col("month").cast("long"))
                 .rowsBetween(-1, -1))

df = df.withColumn("previous_salary", F.max("exactoriginalamount").over(custom_window))

Since the frame rowsBetween(-1, -1) contains only the immediately preceding row of each partition, F.max over it simply returns that row's value, making this equivalent to F.lag(..., 1).
