I have a minimum of 12 periods in a list; these are not fixed and there may be more depending on the selected product. I also have a dict with the period as key and a list of products as the value.
{
"20191": ["prod1","prod2","prod3"],
"20192": ["prod2","prod3"],
"20193": ["prod2"]
}
I need to select the data based on period and compute the sum of the amounts for each period.
sample_data
period | product | amount |
---|---|---|
20191 | prod1 | 30 |
20192 | prod1 | 30 |
20191 | prod2 | 20 |
20191 | prod3 | 60 |
20193 | prod1 | 30 |
20193 | prod2 | 30 |
output
period | amount |
---|---|
20191 | 110 |
20192 | 0 |
20193 | 30 |
Basically, for each period, select only the products listed in the dict for that period and sum up their amounts.
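In plain Python, the computation I want looks roughly like this (just an illustration with the sample rows above; the real data is in Spark):
period_products = {
    "20191": ["prod1", "prod2", "prod3"],
    "20192": ["prod2", "prod3"],
    "20193": ["prod2"],
}
rows = [("20191", "prod1", 30), ("20192", "prod1", 30), ("20191", "prod2", 20),
        ("20191", "prod3", 60), ("20193", "prod1", 30), ("20193", "prod2", 30)]
# For each period, keep only the allowed products and sum their amounts
expected = {
    period: sum(amount for per, product, amount in rows
                if per == period and product in products)
    for period, products in period_products.items()
}
# expected == {'20191': 110, '20192': 0, '20193': 30}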
My code, which is taking a lot of time:
from functools import reduce
from pyspark.sql import DataFrame, functions as F

list_series = []
df = spark.read.csv(path, header=True)
periods = [row["period"] for row in df.select("period").distinct().collect()]
for period in periods:
    df1 = (df.filter(F.col("period") == period)
             .filter(F.col("product").isin(period_products[period]))
             .groupBy("period", "product")
             .agg(F.sum("amount").alias("amount")))
    list_series.append(df1)
dataframe = reduce(DataFrame.unionAll, list_series)
Is there any way I can modify this to improve the performance?
CodePudding user response:
Solution
Flatten the input dictionary into a list of tuples, create a new Spark dataframe called filters, join it with the original dataframe on the period and product columns, then group by period and aggregate amount with sum.
from pyspark.sql import functions as F

# dct is the period -> products dictionary from the question
d = [(prod, period) for period, prods in dct.items() for prod in prods]
filters = spark.createDataFrame(d, schema=['product', 'period'])

result = (
    df
    .join(filters, on=['period', 'product'], how='right')
    .groupby('period')
    .agg(F.sum('amount').alias('amount'))
    .fillna(0)
)
result.show()
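The right join keeps every (period, product) pair from filters even when there is no matching row in df, so a period like 20192 survives the join with a null sum, which fillna(0) then turns into 0.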
Result
+------+------+
|period|amount|
+------+------+
| 20191|   110|
| 20192|     0|
| 20193|    30|
+------+------+
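Since filters is built from a small dict, broadcasting it in the join may help avoid shuffling the larger dataframe; a small optional tweak on top of the snippet above:
from pyspark.sql import functions as F

result = (
    df
    .join(F.broadcast(filters), on=['period', 'product'], how='right')
    .groupby('period')
    .agg(F.sum('amount').alias('amount'))
    .fillna(0)
)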
CodePudding user response:
With the following input:
from pyspark.sql import functions as F
df = spark.createDataFrame(
[('20191', 'prod1', 30),
('20192', 'prod1', 30),
('20191', 'prod2', 20),
('20191', 'prod3', 60),
('20193', 'prod1', 30),
('20193', 'prod2', 30)],
['period', 'product', 'amount'])
periods = ["20191", "20192", "20193"]
period_products = {
"20191": ["prod1","prod2","prod3"],
"20192": ["prod2","prod3"],
"20193": ["prod2"]
}
To make your script more performant, you need to remove the steps that split one dataframe into several and then union them all back together; do everything in a single dataframe without splitting.
You can build the filter condition in Python, supply it to the filter function, and then aggregate.
conds = [f"((period = '{p}') and (product ='{prod}'))" for p in periods for prod in period_products[p]]
cond = ' or '.join(conds)
df_periods = spark.createDataFrame([(periods,)]).select(
F.explode('_1').alias('period')
)
df = (df_periods
.join(df.filter(cond), 'period', 'left')
.groupBy('period', 'product')
.agg(F.sum('amount').alias('amount'))
)
df.show()
# +------+-------+------+
# |period|product|amount|
# +------+-------+------+
# | 20191|  prod3|    60|
# | 20191|  prod2|    20|
# | 20191|  prod1|    30|
# | 20193|  prod2|    30|
# | 20192|   null|  null|
# +------+-------+------+
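If you also want period 20192 to show 0 instead of null, as in the desired output, a fillna on the amount column should do it (a small addition on top of the snippet above):
# Replace the null sum for period 20192 with 0, matching the desired output
df = df.fillna(0, subset=['amount'])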