I'm trying to build a replacement for pandas' .shift() function in PySpark. I'm pretty close, but I still need one final touch to make everything work: applying the shift within the right groups.
I have a dataframe like this one:
month | userid | amount | counterparty_iban |
---|---|---|---|
1 | John Jones | 2400 | ENG12345 |
4 | John Jones | 1200 | ENG12345 |
8 | John Jones | 2300 | ENG12345 |
5 | John Jones | 3000 | AM12345 |
9 | John Jones | 5000 | AM12345 |
12 | John Jones | 6000 | AM12345 |
1 | Joe Jones | 1200 | AM12345 |
2 | Joe Jones | 2400 | AM12345 |
3 | Joe Jones | 5000 | AM12345 |
I need to change the dataframe above to the format of the table below:
month | userid | amount | counterparty_iban | previous_salary |
---|---|---|---|---|
1 | John Jones | 2400 | ENG12345 | null |
4 | John Jones | 1200 | ENG12345 | 2400 |
8 | John Jones | 2300 | ENG12345 | 1200 |
5 | John Jones | 3000 | AM12345 | null |
9 | John Jones | 5000 | AM12345 | 3000 |
12 | John Jones | 6000 | AM12345 | 5000 |
1 | Joe Jones | 1200 | AM12345 | null |
2 | Joe Jones | 2400 | AM12345 | 1200 |
3 | Joe Jones | 5000 | AM12345 | 2400 |
Here is the code to create the input dataframe:
columns = ["month", "userid", 'exactoriginalamount', 'counterparty_iban']
data = [("1", "John Jones", "2400", 'ENG12345'),
("4", "John Jones", "1200", 'ENG12345'),
("8", "John Jones", "2300", 'ENG12345'),
("5", "John Jones", "3000", 'AM12345'),
("9", "John Jones", "5000", 'AM12345'),
("12", "John Jones", "6000", 'AM12345'),
("1", "Joe Jones", "1200", 'AM12345'),
("2", "Joe Jones", "2400", 'AM12345'),
("3", "Joe Jones", "5000", 'AM12345')]
df = spark.createDataFrame(data=data, schema=columns)
I've been trying numerous variations of the following code:
from pyspark.sql import functions as F
from pyspark.sql.window import Window

w = Window().partitionBy().orderBy(F.col('userid'))
df = df.withColumn('previous_salary', F.lag('exactoriginalamount', 1).over(w))
However, I somehow need to group by "userid" and "counterparty_iban" so that the "previous_salary" column shows the right data for each group.
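For reference, this is roughly the pandas behaviour I'm trying to reproduce (just a sketch, reusing the data and columns lists from above in a pandas DataFrame I'll call pdf):
import pandas as pd

pdf = pd.DataFrame(data, columns=columns)
pdf["month"] = pdf["month"].astype(int)
pdf = pdf.sort_values("month")

# shift the amount within each (userid, counterparty_iban) group
pdf["previous_salary"] = (
    pdf.groupby(["userid", "counterparty_iban"])["exactoriginalamount"].shift(1)
)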
CodePudding user response:
You need to describe the partitions properly. From your example output, it looks like you want partitions (windows) based on "userid" and "counterparty_iban"; the lag function is then evaluated in each of these partitions separately. In the script below I also used cast("long") on the "month" column, because with your code "month" is created as a string, and ordering by a string column gives a different sort order than ordering by a number (e.g. "12" sorts before "2").
w = Window.partitionBy("userid", "counterparty_iban").orderBy(F.col("month").cast("long"))
df = df.withColumn("previous_salary", F.lag("exactoriginalamount").over(w))
df.show()
# +-----+----------+-------------------+-----------------+---------------+
# |month|    userid|exactoriginalamount|counterparty_iban|previous_salary|
# +-----+----------+-------------------+-----------------+---------------+
# |    1| Joe Jones|               1200|          AM12345|           null|
# |    2| Joe Jones|               2400|          AM12345|           1200|
# |    3| Joe Jones|               5000|          AM12345|           2400|
# |    5|John Jones|               3000|          AM12345|           null|
# |    9|John Jones|               5000|          AM12345|           3000|
# |   12|John Jones|               6000|          AM12345|           5000|
# |    1|John Jones|               2400|         ENG12345|           null|
# |    4|John Jones|               1200|         ENG12345|           2400|
# |    8|John Jones|               2300|         ENG12345|           1200|
# +-----+----------+-------------------+-----------------+---------------+
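If you prefer, you can also cast the column once up front instead of inside the window definition; a minimal sketch with the same DataFrame:
# cast "month" to a numeric type once, so every later orderBy sorts numerically
df = df.withColumn("month", F.col("month").cast("long"))

w = Window.partitionBy("userid", "counterparty_iban").orderBy("month")
df = df.withColumn("previous_salary", F.lag("exactoriginalamount").over(w))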
CodePudding user response:
You can do something like this, taking the previous row's value with a one-row window frame (note that the frame has to be defined after the ordering, and that the value column is "exactoriginalamount"):
from pyspark.sql import functions as F
from pyspark.sql.window import Window

custom_window = (Window.partitionBy("userid", "counterparty_iban")
                 .orderBy(F.col("month").cast("long"))
                 .rowsBetween(-1, -1))
df = df.withColumn("previous_salary", F.max("exactoriginalamount").over(custom_window))