I have some data like this:
data = [("1","1"), ("1","1"), ("1","1"), ("2","1"), ("2","1"), ("3","1"), ("3","1"), ("4","1")]
df = spark.createDataFrame(data=data, schema=["id","imp"])
df.createOrReplaceTempView("df")
+---+---+
| id|imp|
+---+---+
|  1|  1|
|  1|  1|
|  1|  1|
|  2|  1|
|  2|  1|
|  3|  1|
|  3|  1|
|  4|  1|
+---+---+
I want the count of rows per ID, its running sum, and the total sum. This is the code I'm using:
query = """
select id,
count(id) as count,
sum(count(id)) over (order by count(id) desc) as running_sum,
sum(count(id)) over () as total_sum
from df
group by id
order by count desc
"""
spark.sql(query).show()
+---+-----+-----------+---------+
| id|count|running_sum|total_sum|
+---+-----+-----------+---------+
|  1|    3|          3|        8|
|  2|    2|          7|        8|
|  3|    2|          7|        8|
|  4|    1|          8|        8|
+---+-----+-----------+---------+
The problem is with the running_sum column. For some reason it automatically groups the rows with count 2 while summing and shows 7 for both ID 2 and ID 3.
This is the result I'm expecting:
+---+-----+-----------+---------+
| id|count|running_sum|total_sum|
+---+-----+-----------+---------+
|  1|    3|          3|        8|
|  2|    2|          5|        8|
|  3|    2|          7|        8|
|  4|    1|          8|        8|
+---+-----+-----------+---------+
CodePudding user response:
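The running sum ties because a window with only an ORDER BY defaults to a RANGE frame, which treats rows with equal count as peers and sums them together. Compute the counts in a subquery, then do the running sum in an outer query with an explicit ROWS frame.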
spark.sql('''
    select *,
           sum(cnt) over (order by id rows between unbounded preceding and current row) as run_sum,
           sum(cnt) over (partition by '1') as tot_sum
    from (
        select id, count(id) as cnt
        from df  -- the temp view registered in the question
        group by id)
'''). \
    show()
# +---+---+-------+-------+
# | id|cnt|run_sum|tot_sum|
# +---+---+-------+-------+
# |  1|  3|      3|      8|
# |  2|  2|      5|      8|
# |  3|  2|      7|      8|
# |  4|  1|      8|      8|
# +---+---+-------+-------+
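To see why the original query showed 7 twice, here is a small plain-Python model of the two frame types. It is only an illustration of the semantics, not Spark code; the function names `running_sum_rows` and `running_sum_range` are made up for this sketch, and the input is the per-ID counts from the question, already sorted by count descending.

```python
from itertools import groupby

# (id, count) rows, sorted by count descending as in the question.
rows = [("1", 3), ("2", 2), ("3", 2), ("4", 1)]


def running_sum_rows(rows):
    """Model of a ROWS frame: each row adds only its own count."""
    total, out = 0, []
    for _id, cnt in rows:
        total += cnt
        out.append(total)
    return out


def running_sum_range(rows):
    """Model of a RANGE frame: rows with the same ORDER BY value
    (peers) are added together and all receive the same total."""
    total, out = 0, []
    for _cnt, peers in groupby(rows, key=lambda r: r[1]):
        peers = list(peers)
        total += sum(cnt for _id, cnt in peers)
        out.extend([total] * len(peers))
    return out


print(running_sum_range(rows))  # [3, 7, 7, 8]  (what the question got)
print(running_sum_rows(rows))   # [3, 5, 7, 8]  (what the question wanted)
```

IDs 2 and 3 both have count 2, so under the RANGE model they are peers and share the total 7. The explicit ROWS frame avoids that.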
Using the DataFrame API:
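from pyspark.sql import functions as func, Window as wd

# df is the DataFrame created in the question
df. \
    groupBy('id'). \
    agg(func.count('id').alias('cnt')). \
    withColumn('run_sum',
               # explicit ROWS frame so tied rows are not summed together
               func.sum('cnt').over(wd.partitionBy().orderBy('id').rowsBetween(wd.unboundedPreceding, wd.currentRow))
               ). \
    withColumn('tot_sum', func.sum('cnt').over(wd.partitionBy())). \
    show()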
# +---+---+-------+-------+
# | id|cnt|run_sum|tot_sum|
# +---+---+-------+-------+
# |  1|  3|      3|      8|
# |  2|  2|      5|      8|
# |  3|  2|      7|      8|
# |  4|  1|      8|      8|
# +---+---+-------+-------+
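Both versions above order the window by id, which happens to match the count order for this data. If the running sum should follow count descending as in the question, one option is to order by the count and add id as a tie-breaker, still with an explicit ROWS frame. The sketch below checks that window clause using Python's built-in sqlite3 as a stand-in engine (an assumption made only so it runs without a Spark session; it needs SQLite 3.25 or newer for window functions). The same clause should carry over to spark.sql, but that was not run here.

```python
import sqlite3

# In-memory stand-in for the `df` temp view from the question.
con = sqlite3.connect(":memory:")
con.execute("create table df (id text, imp text)")
con.executemany(
    "insert into df values (?, ?)",
    [("1", "1"), ("1", "1"), ("1", "1"), ("2", "1"),
     ("2", "1"), ("3", "1"), ("3", "1"), ("4", "1")],
)

query = """
select id,
       cnt,
       -- id breaks ties between equal counts; ROWS keeps peers separate
       sum(cnt) over (order by cnt desc, id
                      rows between unbounded preceding and current row) as running_sum,
       sum(cnt) over () as total_sum
from (select id, count(id) as cnt from df group by id) t
order by cnt desc, id
"""

result = con.execute(query).fetchall()
for row in result:
    print(row)
# ('1', 3, 3, 8)
# ('2', 2, 5, 8)
# ('3', 2, 7, 8)
# ('4', 1, 8, 8)
```

The tie-breaker makes the order of tied rows deterministic, so the running sum does not depend on how the engine happens to arrange IDs with equal counts.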