Filter list of rows based on a column value in PySpark


I have a list of rows obtained with collect. How can I get the "num_samples" value where sample_label == 0? In other words, how can I filter a list of rows based on a column value?

[Row(sample_label=1, num_samples=14398),
 Row(sample_label=0, num_samples=12500),
 Row(sample_label=2, num_samples=98230)]

CodePudding user response:

Filter the DataFrame rows on the needed condition before collecting them:

df.filter(df.sample_label == 0).collect()
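
A minimal end-to-end sketch of that approach (the SparkSession setup and the sample data here are assumptions, reconstructed from the rows shown in the question):

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Hypothetical DataFrame rebuilt from the rows in the question
df = spark.createDataFrame(
    [(1, 14398), (0, 12500), (2, 98230)],
    ["sample_label", "num_samples"],
)

# Filter on the executors first, then collect only the matching rows
rows = df.filter(df.sample_label == 0).collect()
num_samples = rows[0].num_samples  # 12500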

CodePudding user response:

One possible approach is a list comprehension:

from pyspark.sql import Row

data = [Row(sample_label=1, num_samples=14398),
        Row(sample_label=0, num_samples=12500),
        Row(sample_label=2, num_samples=98230)]

filtered_data = [row.num_samples for row in data if row.sample_label == 0]
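
If sample_label is unique, the comprehension yields a single element, so the requested value is simply:

num_samples = filtered_data[0]  # 12500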

EDIT: as @RomanPerekhrest says, filter before you collect. Why? Per the collect() documentation:

This method should only be used if the resulting array is expected to be small, as all the data is loaded into the driver’s memory.

Collect the data only when you are sure you really need it in that form. If you do, filter it heavily first to reduce the chance of the driver node failing with an out-of-memory (OOM) error.
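
For instance, when only one value is needed, selecting that column and pulling back a single row keeps almost nothing on the driver (a sketch reusing the hypothetical df defined earlier):

# Fetch just the one needed value instead of collecting whole rows
row = df.filter(df.sample_label == 0).select("num_samples").first()
num_samples = row.num_samples if row is not None else None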
