I have a database of patients and their results. Below are demo dataframes:
import pandas as pd
import numpy as np
from scipy.stats import linregress
data = [[1 , '20210201', 4567, 40],
[1 , '20210604', 4567, 55],
[1 , '20200405', 2574, 42],
[1 , '20210602', 2574, 55],
[2 , '20210201', 4567, 25],
[2 , '20210604', 4567, 32],
[2 , '20200405', 2574, 70],
[2 , '20210602', 2574, 46]]
df = pd.DataFrame(data, columns=['id', 'date', 'test_id', 'result'])
df.date = pd.to_datetime(df.date, format='%Y%m%d') # format date field
df
id date test_id result
0 1 2021-02-01 4567 40
1 1 2021-06-04 4567 55
2 1 2020-04-05 2574 42
3 1 2021-06-02 2574 55
4 2 2021-02-01 4567 25
5 2 2021-06-04 4567 32
6 2 2020-04-05 2574 70
7 2 2021-06-02 2574 46
data = [[1 , '20220101'],
[2 , '20220102']]
customers = pd.DataFrame(data, columns=['id', 'start_date'])
customers.start_date = pd.to_datetime(customers.start_date, format='%Y%m%d') # format date field
print(customers)
id start_date
0 1 2022-01-01
1 2 2022-01-02
And here is a function that takes a patient and an initial date and returns aggregated results for each test over a specific time period before that date:
def patient_agg_results(df, patient_ID, X, Y, firstAF):
    result = pd.DataFrame()
    X_date = firstAF - pd.DateOffset(months=X)
    Y_date = firstAF - pd.DateOffset(months=X + Y)
    # get results of specific patient within the timeframe
    patient_results = df[(df['id'] == patient_ID) & (df['date'] < X_date) & (df['date'] > Y_date)]  # ***
    if len(patient_results) > 0:
        # Calculate mean
        curr_result = pd.DataFrame(patient_results.groupby('test_id').mean()['result'])
        curr_result = curr_result.set_index(curr_result.index.astype(str) + '_mean')
        result = pd.concat([result, curr_result])
        # Calculate newest result
        curr_result = pd.DataFrame(patient_results.groupby('test_id').max()['result'])
        curr_result = curr_result.set_index(curr_result.index.astype(str) + '_new')
        result = pd.concat([result, curr_result])
        # Calculate oldest result
        curr_result = pd.DataFrame(patient_results.groupby('test_id').min()['result'])
        curr_result = curr_result.set_index(curr_result.index.astype(str) + '_old')
        result = pd.concat([result, curr_result])
        # Calculate STD
        curr_result = pd.DataFrame(patient_results.groupby('test_id').std()['result'])
        curr_result = curr_result.set_index(curr_result.index.astype(str) + '_std')
        result = pd.concat([result, curr_result])
        # Calculate slope
        patient_results['int_date'] = pd.to_datetime(patient_results['date']).astype(np.int64)  # create integer date
        curr_result = pd.DataFrame(patient_results.groupby('test_id')[['result', 'int_date']].apply(lambda v: linregress(v.int_date, v.result)[0]))
        curr_result.columns = ['result']
        curr_result = curr_result.set_index(curr_result.index.astype(str) + '_slope')
        result = pd.concat([result, curr_result])
        result['id'] = patient_ID
    return result.to_dict()
I use the function like this:
customers['lab_results'] = customers.apply(lambda row: patient_agg_results(df,row['id'],12,12,row['start_date']),axis=1)
The problem is that my original datasets include about a million patients and a few million results, which makes this code run for a few days. The most time-consuming line is the filtering line (marked # *** above).
Any idea how to make it more time-efficient?
CodePudding user response:
This is a great question. Usually when the data all fits in memory and things still run for days, it is because you have run into a combinatorial execution nightmare. Your pointing to the selection line as the biggest time consumer makes total sense – this is a CPU-bound process.
If you assume that df has 3 million rows, then that filter line – since it runs inside an apply over customers – performs 9 million comparisons, and it does so 1 million times. That's a big number that could lead to days of execution. Spark is excellent at handling large data, but I suspect you would hit the same CPU-bound combinatorial nightmare in Spark.
You may reap huge benefits by approaching this from the other direction – a groupby on df:
df.groupby('id').apply(...)
The apply function from groupby will receive a full sub-frame of all the columns and rows for each id. With this approach the date selection runs on an average of 3 rows (times two date comparisons) instead of 3 million rows times two for each id. And the repeated 3-million-row id selection goes away completely; it is handled once during the initial groupby operation.
Within the df.groupby(...).apply function you can look up what you need from the customers data frame with customers.at[id, 'start_date'] (after setting the customers data frame's index to id, which would also speed things up).
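A minimal sketch of that lookup, reusing the customers frame from the question:
customers = customers.set_index('id')    # one-time reindex so .at lookups by patient id are fast
firstAF = customers.at[1, 'start_date']  # e.g. the start date for patient 1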
There may be some savings in consolidating the separate groupby calls in your function into a single groupby.agg() call. I didn't get a chance to fully digest that, though, so I can't offer a more solid suggestion.
Once you get the combinations under control you can go here for some great ways to improve performance: https://pandas.pydata.org/docs/user_guide/enhancingperf.html
Man, I love this kind of problem! I wish I had more time.
Here is some starter code to communicate the idea:
def jch_agg_result(gf, X, Y):
    # assumes customers has been reindexed by id: customers = customers.set_index('id')
    firstAF = customers.at[gf['id'].iat[0], 'start_date']
    X_date = firstAF - pd.DateOffset(months=X)
    Y_date = firstAF - pd.DateOffset(months=X + Y)
    patient_results = gf[(gf['date'] < X_date) & (gf['date'] > Y_date)]
    if len(patient_results) > 0:
        tid = gf['test_id'].iat[0]
        m = {f'{tid}_mean': gf['result'].mean(), f'{tid}_max': gf['result'].max()}
        return m
    return np.nan
Then:
df.groupby(['id','test_id']).apply(jch_agg_result,12,12).dropna()
id test_id
1 2574 {'2574_mean': 48.5, '2574_max': 55}
2 2574 {'2574_mean': 58.0, '2574_max': 70}
You'd then just need to merge the applicable dates back in to get to your desired result. And I'm not sure I got the inner mean, max, etc. right, but it's there so you can modify it to meet your needs.
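For example, one way to fold those per-patient dicts back onto customers (a sketch, assuming customers has been reindexed by id as mentioned above):
agg = df.groupby(['id', 'test_id']).apply(jch_agg_result, 12, 12).dropna()
# collapse the per-(id, test_id) dicts into a single dict per patient
per_patient = agg.groupby(level='id').apply(lambda s: {k: v for d in s for k, v in d.items()})
customers['lab_results'] = per_patient  # aligns on the id index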
CodePudding user response:
PySpark should be able to help you out. There may be other, faster solutions, but this one will be quick to both implement and run. Most of the functions are broadly similar between PySpark and pandas, and in my experience PySpark handles large datasets with simple operations like this well.
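For illustration, a rough sketch of that idea on the question's demo frames (the X/Y window follows the question, the aggregation aliases are made up, and the slope step is omitted), not a drop-in replacement:
from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.getOrCreate()

# the demo frames from the question (customers here still has 'id' as a column)
results = spark.createDataFrame(df)
starts = spark.createDataFrame(customers)

X, Y = 12, 12
windowed = (results.join(starts, on='id')
            .where((F.col('date') < F.add_months('start_date', -X)) &
                   (F.col('date') > F.add_months('start_date', -(X + Y)))))

agg = (windowed.groupBy('id', 'test_id')
       .agg(F.mean('result').alias('result_mean'),
            F.max('result').alias('result_max'),
            F.min('result').alias('result_min'),
            F.stddev('result').alias('result_std')))
agg.show()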