Filter PySpark dataframe into a list of dataframes

Time:07-14

I have a PySpark dataframe and I want to split it into a list of dataframes, one for each unique combination of values in some columns.

import pandas as pd
from pyspark.sql import SparkSession
spark_session = SparkSession.builder.enableHiveSupport().getOrCreate()

columns = ["language","users_count","apple"]
data = [("Java", 1, 0.0), ("Scala", 4, -4.0), ("Java", 1, 0.0)]

pyspark_df = spark_session.createDataFrame(data).toDF(*columns)

pandas_df = pd.DataFrame(data, columns=columns)

# Operation I want to replicate in PySpark:
column_list = ['language','users_count']  # these names and the number of columns can change at runtime
unique_dfs = [df for _, df in pandas_df.groupby(column_list, as_index=False)]

Another possible approach is to add a column to the PySpark df that holds a combined key, e.g. string(language + users_count), and then filter on each unique value of that column to get the dataframes.

Answer:

If you know exactly which values you need, just use filter, because it is efficient in Spark.

from pyspark.sql import functions as F

df = pyspark_df.filter(
    (F.col('language') == 'Java') &
    (F.col('users_count') == 1)
)

If you REALLY need every existing combination of those columns as a separate dataframe, you will have to run distinct (i.e. a shuffle, which you want to avoid) and an inefficient collect to the driver:

from pyspark.sql import functions as F

df = pyspark_df.select('language', 'users_count').distinct()
unique_dfs = []
for lang, uc in df.collect():
    unique_dfs.append(
        pyspark_df.filter(
            (F.col('language') == lang) &
            (F.col('users_count') == uc)
        )
    )

Results:

unique_dfs[0].show()
# +--------+-----------+-----+
# |language|users_count|apple|
# +--------+-----------+-----+
# |    Java|          1|  0.0|
# |    Java|          1|  0.0|
# +--------+-----------+-----+
unique_dfs[1].show()
# +--------+-----------+-----+
# |language|users_count|apple|
# +--------+-----------+-----+
# |   Scala|          4| -4.0|
# +--------+-----------+-----+

Note: Here Java happens to be at index 0 and Scala at index 1, but it could just as well be the opposite. There is no determinism there: distinct does not return the combinations in any guaranteed order, so you don't know in which order they will come back to the driver when you call collect. So what you asked for is probably not what you truly need.

Answer:

Create a rank using a window function: order the window by the columns you want to group on and apply dense_rank, so that each unique combination of values gets its own number. Then iterate from 1 to the number of distinct ranks, filter the dataframe on each rank, and store the resulting dataframes in a list. I hope this helps!
