In the dataframe below, there are several apartments with different jobs:
+---+---------+----+
|id |apartment|job |
+---+---------+----+
|1  |Ap1      |dev |
|2  |Ap1      |anyl|
|3  |Ap2      |dev |
|4  |Ap2      |anyl|
|5  |Ap2      |anyl|
|6  |Ap2      |dev |
|7  |Ap2      |dev |
|8  |Ap2      |dev |
|9  |Ap3      |anyl|
|10 |Ap3      |dev |
|11 |Ap3      |dev |
+---+---------+----+
For each apartment, the number of rows with job='dev'
should be equal to the number of rows with job='anyl'
(as is already the case for Ap1). How can I delete the redundant 'dev' rows
in all the apartments?
The expected result:
+---+---------+----+
|id |apartment|job |
+---+---------+----+
|1  |Ap1      |dev |
|2  |Ap1      |anyl|
|3  |Ap2      |dev |
|4  |Ap2      |anyl|
|5  |Ap2      |anyl|
|6  |Ap2      |dev |
|9  |Ap3      |anyl|
|10 |Ap3      |dev |
+---+---------+----+
I guess I should use Window functions to deal with that, but I couldn't figure it out.
CodePudding user response:
I think you first need to find out how many 'anyl' rows you have for every 'apartment', and then use that count to delete all the excess 'dev' rows. So: first an aggregation, then a join, and then the window function row_number, before you can filter out what you don't need.
Setup:
from pyspark.sql import functions as F, Window as W
df = spark.createDataFrame(
    [(1, 'Ap1', 'dev'),
     (2, 'Ap1', 'anyl'),
     (3, 'Ap2', 'dev'),
     (4, 'Ap2', 'anyl'),
     (5, 'Ap2', 'anyl'),
     (6, 'Ap2', 'dev'),
     (7, 'Ap2', 'dev'),
     (8, 'Ap2', 'dev'),
     (9, 'Ap3', 'anyl'),
     (10, 'Ap3', 'dev'),
     (11, 'Ap3', 'dev')],
    ['id', 'apartment', 'job']
)
Script:
# count the 'anyl' rows per apartment
df_grp = df.filter(F.col('job') == 'anyl').groupBy('apartment').count()
# attach that count to every row of the original dataframe
df = df.join(df_grp, 'apartment', 'left')
# number the rows within each (apartment, job) group, ordered by id
w = W.partitionBy('apartment', 'job').orderBy('id')
df = df.withColumn('_rn', F.row_number().over(w))
# keep at most `count` rows per group, then drop the helper columns
df = df.filter('_rn <= count')
df = df.select('id', 'apartment', 'job')
df.show()
# +---+---------+----+
# | id|apartment| job|
# +---+---------+----+
# |  2|      Ap1|anyl|
# |  1|      Ap1| dev|
# |  4|      Ap2|anyl|
# |  5|      Ap2|anyl|
# |  3|      Ap2| dev|
# |  6|      Ap2| dev|
# |  9|      Ap3|anyl|
# | 10|      Ap3| dev|
# +---+---------+----+
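For comparison, the aggregation + join can also be replaced by a conditional window count. A minimal sketch of that variation, assuming df is the freshly created dataframe from the Setup block (note that the script above reassigns df):
w_cnt = W.partitionBy('apartment')
w_rn = W.partitionBy('apartment', 'job').orderBy('id')
result = (
    df
    # count the 'anyl' rows of each apartment with a window instead of groupBy + join
    .withColumn('_anyl_cnt', F.count(F.when(F.col('job') == 'anyl', True)).over(w_cnt))
    # number the rows within each (apartment, job) group, ordered by id
    .withColumn('_rn', F.row_number().over(w_rn))
    # keep at most _anyl_cnt rows per group, then drop the helper columns
    .filter(F.col('_rn') <= F.col('_anyl_cnt'))
    .select('id', 'apartment', 'job')
)
result.show()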
CodePudding user response:
Using a left semi join instead of the groupBy + filter combo suggested by @ZygD might be more efficient:
>>> from pyspark.sql import Window
>>> from pyspark.sql.functions import col, row_number
>>> df1 = df.withColumn('rn', row_number().over(Window.partitionBy('apartment', 'job').orderBy('id')))
>>> df2 = df1.alias('a').join(
...     df1.where("job = 'anyl'").alias('b'),
...     (col('a.apartment') == col('b.apartment')) & (col('a.rn') == col('b.rn')),
...     'leftsemi')
>>> df2.show(truncate=False)
+---+---------+----+---+
|id |apartment|job |rn |
+---+---------+----+---+
|1  |Ap1      |dev |1  |
|2  |Ap1      |anyl|1  |
|3  |Ap2      |dev |1  |
|4  |Ap2      |anyl|1  |
|5  |Ap2      |anyl|2  |
|6  |Ap2      |dev |2  |
|9  |Ap3      |anyl|1  |
|10 |Ap3      |dev |1  |
+---+---------+----+---+
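If the helper rn column is not wanted in the final output, it can simply be dropped afterwards, for example:
>>> df2.drop('rn').orderBy('id').show(truncate=False)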