The goal is to use a pandas user-defined function as a window function in PySpark. Here is a minimal example. df is a pandas DataFrame and a Spark table:
import pandas as pd
from pyspark.sql import SparkSession
df = pd.DataFrame(
    {'x': [1, 1, 2, 2, 2, 3, 3],
     'y': [1, 2, 3, 4, 5, 6, 7]})
spark = SparkSession.builder.getOrCreate()
spark.createDataFrame(df).createOrReplaceTempView('df')
Here is df as a Spark table:
In [10]: spark.sql('SELECT * FROM df').show()
+---+---+
|  x|  y|
+---+---+
|  1|  1|
|  1|  2|
|  2|  3|
|  2|  4|
|  2|  5|
|  3|  6|
|  3|  7|
+---+---+
The minimal task is to implement a cumulative sum of y partitioned by x. Without any pandas user-defined function, that looks like:
dx = spark.sql(f"""
SELECT x, y,
SUM(y) OVER (PARTITION BY x ORDER BY y) AS ysum
FROM df
ORDER BY x""").toPandas()
where dx is then:
In [2]: dx
Out[2]:
   x  y  ysum
0  1  1     1
1  1  2     3
2  2  3     3
3  2  4     7
4  2  5    12
5  3  6     6
6  3  7    13
And a non-working attempt to do the same with pandas_udf is:
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import DoubleType
@pandas_udf(returnType=DoubleType())
def func(x: pd.Series) -> pd.Series:
    return x.cumsum()
spark.udf.register('func', func)
dx = spark.sql(f"""
SELECT x, y,
func(y) OVER (PARTITION BY x ORDER BY y) AS ysum
FROM df
ORDER BY x""").toPandas()
which returns this error:
AnalysisException: Expression 'func(y#1L)' not supported within a window function.;
...
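For context, Spark rejects scalar (Series-to-Series) pandas UDFs inside an OVER clause; as far as I know, only grouped-aggregate pandas UDFs (a whole pd.Series in, a single scalar out) can be used as window functions, and then over an explicit rowsBetween frame. A rough sketch of that route via the DataFrame API, assuming Spark 3.x (pd_sum and the window w are names introduced here for illustration):
from pyspark.sql import Window
from pyspark.sql.functions import pandas_udf

# grouped-aggregate pandas UDF: a whole pd.Series in, one scalar out
@pandas_udf('long')
def pd_sum(y: pd.Series) -> int:
    return int(y.sum())

# running frame: from the start of the partition up to the current row
w = (Window.partitionBy('x').orderBy('y')
     .rowsBetween(Window.unboundedPreceding, Window.currentRow))

spark.table('df').withColumn('ysum', pd_sum('y').over(w)).show()
This should yield the same ysum column as the plain SQL SUM(y) OVER (...) above.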
UPDATE: Based on the answer by wwnde, the solution was:
def pdf_cumsum(pdf):
    pdf['ysum'] = pdf['y'].cumsum()
    return pdf

sdf = spark.createDataFrame(df)  # the Spark DataFrame created from df above
dx = sdf.groupby('x').applyInPandas(pdf_cumsum, schema='x long, y long, ysum long').toPandas()
CodePudding user response:
Use mapInPandas from the Map pandas function API:
from typing import Iterator
from pyspark.sql.functions import lit

sch = sdf.withColumn('ysum', lit(3)).schema  # input columns plus a ysum column

def cumsum_pdf(iterator: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
    for pdf in iterator:
        yield pdf.assign(ysum=pdf.groupby('x')['y'].cumsum())

sdf.mapInPandas(cumsum_pdf, schema=sch).show()
Outcome
+---+---+----+
|  x|  y|ysum|
+---+---+----+
|  1|  1|   1|
|  1|  2|   3|
|  2|  3|   3|
|  2|  4|   7|
|  2|  5|  12|
|  3|  6|   6|
|  3|  7|  13|
+---+---+----+