I have a SQL query like the one below:
select col4, col5 from TableA where col1 = 'x'
intersect
select col4, col5 from TableA where col1 = 'y'
intersect
select col4, col5 from TableA where col1 = 'z'
How can I convert this SQL to its PySpark equivalent? I can create 3 DataFrames and then intersect them like below:
df1 = spark.sql("select col4, col5 from TableA where col1 = 'x'")
df2 = spark.sql("select col4, col5 from TableA where col1 = 'y'")
df3 = spark.sql("select col4, col5 from TableA where col1 = 'z'")
df_result = df1.intersect(df2)
df_result = df_result.intersect(df3)
But I don't think that's a good approach if there are more intersect queries.
Also, the list [x, y, z] is dynamic, meaning it could be [x, y, z, a, b, ...].
Any suggestion?
Answer:
If you want to do several consecutive intersects, there's functools.reduce. Put all your DataFrames in one list and reduce will apply intersect to them one after another:
from functools import reduce
dfs = [df1, df2,...]
df = reduce(lambda a, b: a.intersect(b), dfs)
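But this would be inefficient in your case: every DataFrame in the list reads TableA again, and every intersect adds another distinct-and-join step.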
Since all the data comes from the same DataFrame, I would suggest a rework: instead of splitting df and recombining the pieces with intersect, do a single aggregation followed by a filter.
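To see why one aggregation can replace the chain of intersects, here is a plain-Python model of both approaches (no Spark needed; the function names are illustrative and the rows are the ones from the test below). A (col4, col5) pair survives every INTERSECT exactly when the set of col1 values seen for that pair contains all the requested values.

```python
from collections import defaultdict
from functools import reduce

# Rows as (col4, col5, col1)
rows = [(1, 11, 'y'), (1, 11, 'y'), (1, 11, 'x'), (2, 22, 'x'),
        (1, 11, 'z'), (4, 44, 'z'), (1, 11, 'M')]
vals = ['x', 'y', 'z']


def via_intersect(rows, vals):
    # One set of (col4, col5) pairs per value, then intersect them all
    per_value = [{(c4, c5) for c4, c5, c1 in rows if c1 == v} for v in vals]
    return reduce(set.intersection, per_value)


def via_grouping(rows, vals):
    # Group once: (col4, col5) -> set of col1 values (like collect_set)
    seen = defaultdict(set)
    for c4, c5, c1 in rows:
        seen[(c4, c5)].add(c1)
    # Keep pairs whose set covers every requested value (like forall + array_contains)
    return {key for key, s in seen.items() if set(vals) <= s}


print(via_intersect(rows, vals))  # {(1, 11)}
print(via_grouping(rows, vals))   # {(1, 11)}
```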
Script (Spark 3.1):
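from pyspark.sql import functions as F

vals = ['x', 'y', 'z']
# Literal array of the values every (col4, col5) pair must have
arr = F.array([F.lit(v) for v in vals])
# One row per (col4, col5) with the distinct col1 values seen for it
df = df.groupBy('col4', 'col5').agg(F.collect_set('col1').alias('set'))
# Keep pairs whose set contains every value in arr
df = df.filter(F.forall(arr, lambda x: F.array_contains('set', x)))
df = df.drop('set')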
Test:
from pyspark.sql import functions as F
df = spark.createDataFrame(
[(1, 11, 'y'),
(1, 11, 'y'),
(1, 11, 'x'),
(2, 22, 'x'),
(1, 11, 'z'),
(4, 44, 'z'),
(1, 11, 'M')],
['col4', 'col5', 'col1'])
vals = ['x', 'y', 'z']
arr = F.array([F.lit(v) for v in vals])
df = df.groupBy('col4', 'col5').agg(F.collect_set('col1').alias('set'))
df = df.filter(F.forall(arr, lambda x: F.array_contains('set', x)))
df = df.drop('set')
df.show()
# +----+----+
# |col4|col5|
# +----+----+
# |   1|  11|
# +----+----+