Pyspark filter or split based on row value


I'm dealing with some data in PySpark. We have the issue that metadata and actual data are mixed: we have strings which are interrupted by a "STOP" string. The number of strings between two "STOP" markers is variable, and we would like to filter out short occurrences.

Below is an example dataframe, where we have ints instead of strings and 0 is the stop signal:

df = spark.createDataFrame([(1,), (0,), (3,), (0,), (4,), (4,), (5,), (0,)])

My goal would now be to have a filter function where I can specify how many elements need to lie between two stop signals in order for the data to be kept. E.g. if min_length were two, we would end up with the dataframe:

df = spark.createDataFrame([(4,), (4,), (5,), (0,)])

My idea was to create a separate column and build a group identifier in there:

df.select("_1", F.when(df["_1"]==2, 0).otherwise(get_counter())).show()

The get_counter function would count how many times we've already seen "STOP" (or 0 in the example df). Due to the distributed nature of Spark, that does not work, though.
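
To illustrate the kind of grouping I have in mind: a cumulative count of the stop rows over a window ordered by monotonically_increasing_id() might replace the stateful counter, but I'm not sure whether this is correct or idiomatic (a rough sketch; min_length, stop_flag, grp and run_len are placeholder names I made up):

from pyspark.sql import functions as F
from pyspark.sql import Window

min_length = 2  # placeholder: minimum number of data rows between two stop signals

# Add an increasing id so the original row order is available to the windows
with_id = df.withColumn("id", F.monotonically_increasing_id())

# 1 for stop rows, 0 for data rows
stop_flag = F.when(F.col("_1") == 0, 1).otherwise(0)

# Group id = number of stop rows strictly before the current row, so each group
# is one run of data rows plus its trailing stop row
w = Window.orderBy("id")  # single-partition window, fine for a sketch
grouped = with_id.withColumn("grp", F.sum(stop_flag).over(w) - stop_flag)

# Keep only groups with at least min_length data (non-stop) rows
w_grp = Window.partitionBy("grp")
(grouped
 .withColumn("run_len", F.sum(1 - stop_flag).over(w_grp))
 .filter(F.col("run_len") >= min_length)
 .drop("id", "grp", "run_len")
 .show())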

Is it somehow possible to achieve this easily by filtering? Or is it maybe possible to split the dataframe every time "STOP" occurs? I could then delete the dataframes that are too short and merge the rest again.

Preferably this would be solved in PySpark or Spark SQL. But if someone knows how to do this with the spark-shell, I'd also be curious :)

CodePudding user response:

Spark SQL implementation:

with t2 as (
  select 
    monotonically_increasing_id() as id
  , _1
  , case when _1 = 0 then 1 else 0 end as stop  
  from 
    t1 
)
, t3 as (
  select 
    * 
  , row_number() over (partition by stop order by id) as stop_seq  
  from 
    t2
)  
select * from t3 where stop_seq > 2
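
If you'd rather stay in PySpark's DataFrame API, the same query translates roughly as follows (a sketch; the column names id, stop and stop_seq simply mirror the CTEs above):

from pyspark.sql import functions as F
from pyspark.sql import Window

# Mirror of CTE t2: an increasing id plus a stop flag
t2 = (df
      .withColumn("id", F.monotonically_increasing_id())
      .withColumn("stop", F.when(F.col("_1") == 0, 1).otherwise(0)))

# Mirror of CTE t3: number the rows within the stop / non-stop partitions
w = Window.partitionBy("stop").orderBy("id")
t3 = t2.withColumn("stop_seq", F.row_number().over(w))

# Final filter from the query above
t3.filter(F.col("stop_seq") > 2).show()

Alternatively, you can register the example dataframe as a temporary view with df.createOrReplaceTempView("t1") and run the query above verbatim through spark.sql(...).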