I have the following scenario: a DataFrame with payment information such as bill id number, client id, bill value, date of payment, etc. In this DF the same client can have N records with the same bill_id. I want to:
- Check whether a bill_id appears more than once for a given customer;
- If so, keep only the most recent record, based on the timestamp;
- If not, keep the single record for that bill_id;
- Store the result in a new DF using a df.where clause.
I tried the following code without success:
df_clients_bills = df_clients_bills.where(
when(countDistinct(df_clients_bills.bill_id) > 1, max(df_clients_bills.payment_date)).otherwise(df_clients_bills)
)
I'm not sure this is the best approach. Any tips that lead to a solution would be appreciated.
Answer:
You can use SQL to accomplish this:
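import datetime
import pandas as pd
from pyspark.sql import SparkSession

# Reuse the active session (e.g. in a notebook) or start a local one
spark = SparkSession.builder.getOrCreate()

# Just a basic example: three payments for the same bill, one day apart
df_pd = pd.DataFrame({
    "bill_id": [1, 1, 1],
    "payment_date": [
        datetime.datetime.utcnow() - datetime.timedelta(days=i) for i in range(3)],
    "value": [1, 2, 3]
})
df_clients_bills = spark.createDataFrame(df_pd)
# registerTempTable is deprecated; createOrReplaceTempView is the current equivalent
df_clients_bills.createOrReplaceTempView("df_clients_bills")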
query = """
SELECT
df_clients_bills.bill_id,
df_clients_bills.payment_date,
df_clients_bills.value
FROM df_clients_bills
INNER JOIN (
SELECT bill_id, MAX(payment_date) AS max_payment_date
FROM df_clients_bills
GROUP BY bill_id
) AS bill_id_max_dates
ON
(df_clients_bills.bill_id = bill_id_max_dates.bill_id) AND
(df_clients_bills.payment_date = bill_id_max_dates.max_payment_date)
"""
result = spark.sql(query)
result.show(5)
Answer:
For your problem, you don't actually need to distinguish between bills that appear once and bills that appear several times, because "take the most recent bill when there are several" also works when there is only one.
So you can use a window function to rank the bills that share the same id from the latest timestamp to the earliest, and then keep the first row of each group, as follows:
from pyspark.sql import functions as F
from pyspark.sql import Window
df_clients_bills = df_clients_bills \
.withColumn('row_number', F.row_number().over(Window.partitionBy('bill_id').orderBy(F.desc('payment_date')))) \
.filter(F.col('row_number') == 1) \
.drop('row_number')