Imagine having the following Spark dataframe:
+---+----+----+
| Id|A_rk|B_rk|
+---+----+----+
|  a|   5|   4|
|  b|   7|   7|
|  c|   5|   4|
|  d|   1|   0|
+---+----+----+
I want to create a column called Pair that takes the value of B_rk if two rows have the same values for A_rk and B_rk, and the value of 0 if there is no match. The result would be:
+---+----+----+----+
| Id|A_rk|B_rk|Pair|
+---+----+----+----+
|  a|   5|   4|   4|
|  b|   7|   7|   0|
|  c|   5|   4|   4|
|  d|   1|   0|   0|
+---+----+----+----+
I had a successful attempt with pandas using for loops, but I want to use Spark for better performance.
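For reference, a minimal sketch to build the sample dataframe (assuming an existing SparkSession; values taken from the table above):
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
# recreate the sample dataframe from the question
df = spark.createDataFrame(
    [('a', 5, 4), ('b', 7, 7), ('c', 5, 4), ('d', 1, 0)],
    ['Id', 'A_rk', 'B_rk'],
)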
CodePudding user response:
1.4 million rows is still fine with pandas.
from random import randint
import pandas as pd

# 1.4 million random (A_rk, B_rk) pairs as test data
data = [(randint(0, 1000000), randint(0, 100000)) for _ in range(1400000)]
df = pd.DataFrame(data, columns=['A_rk', 'B_rk'])
# count how many rows share each (A_rk, B_rk) combination
df['cnt'] = df.groupby(['A_rk', 'B_rk']).transform('size')
df.loc[df.cnt > 1, 'Pair'] = df.B_rk  # or df.cnt == 2 if you only want exact pairs (2 rows)
df['Pair'] = df.Pair.fillna(0).astype(int)
I just ran this with 1.4 million rows and it finished in under a second:
389 ms ± 4.77 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
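Applied to the sample data from the question, the same logic yields the expected Pair column (a quick sketch; the single-column transform('size') variant used here is equivalent):
import pandas as pd

df = pd.DataFrame({'Id': list('abcd'), 'A_rk': [5, 7, 5, 1], 'B_rk': [4, 7, 4, 0]})
df['cnt'] = df.groupby(['A_rk', 'B_rk'])['B_rk'].transform('size')
df.loc[df.cnt > 1, 'Pair'] = df.B_rk
df['Pair'] = df.Pair.fillna(0).astype(int)
print(df[['Id', 'A_rk', 'B_rk', 'Pair']])
#   Id  A_rk  B_rk  Pair
# 0  a     5     4     4
# 1  b     7     7     0
# 2  c     5     4     4
# 3  d     1     0     0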
PySpark solution:
from pyspark.sql import Window
from pyspark.sql import functions as F

# count rows per (A_rk, B_rk) combination with a window
w = Window.partitionBy('A_rk', 'B_rk')
df = (df.withColumn('cnt', F.count('A_rk').over(w))
        .withColumn('Pair', F.when(F.col('cnt') > 1, F.col('B_rk')).otherwise(F.lit(0))))
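Dropping the helper column and showing the dataframe should reproduce the expected output from the question (row order may vary):
df.drop('cnt').show()
# +---+----+----+----+
# | Id|A_rk|B_rk|Pair|
# +---+----+----+----+
# |  a|   5|   4|   4|
# |  b|   7|   7|   0|
# |  c|   5|   4|   4|
# |  d|   1|   0|   0|
# +---+----+----+----+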