PySpark - Truncate lists in a column according to first and last occurrences of specific elements


I have a PySpark DataFrame where one column is a list with a given order:

df_in = spark.createDataFrame([
    (1, ['A', 'B', 'A', 'F', 'C', 'D']),
    (2, ['F', 'C', 'B', 'X', 'A', 'D']),
    (3, ['L', 'A', 'B', 'M', 'C'])])

I want to specify two elements e_1 and e_2 (e.g. 'A' and 'C'), such that the resulting column contains a list with all the elements from the first occurrence of either e_1 or e_2 up to and including the last occurrence of either e_1 or e_2. Hence, the resulting df should look like this:

df_out = spark.createDataFrame([
    (1, ['A', 'B', 'A', 'F', 'C']),
    (2, ['C', 'B', 'X', 'A']),
    (3, ['A', 'B', 'M', 'C'])])

How do I achieve this? Thanks in advance!

Best regards

CodePudding user response:

One way to do this is:

  • Join the array into a string with elements separated by ,.
  • Use a regex to extract the required sub-string.
  • Split the string by , back into an array.
  • Cast, if required.
from pyspark.sql import functions as F

df = spark.createDataFrame(
    data=[(1, ["X A", "X B", "X A", "X F", "X C", "X D"]),
          (2, ["X F", "X C", "X B", "X X", "X A", "X D"]),
          (3, ["X L", "X A", "X B", "X M", "X C"])],
    schema=["id", "arr"])

match_list = ["X A", "X C", "X D"]
match_any = "|".join(match_list)
regex = rf"((?:{match_any}).*(?:{match_any}))"

df = df.withColumn("arr", F.concat_ws(",", "arr")) \
       .withColumn("arr", F.regexp_extract("arr", regex, 1)) \
       .withColumn("arr", F.split("arr", ","))

Output:

+---+------------------------------+
|id |arr                           |
+---+------------------------------+
|1  |[X A, X B, X A, X F, X C, X D]|
|2  |[X C, X B, X X, X A, X D]     |
|3  |[X A, X B, X M, X C]          |
+---+------------------------------+
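
Applied to df_in from the question (a minimal sketch, assuming the list elements never contain the , delimiter and matching only on 'A' and 'C'), the same three steps reproduce the expected output:

from pyspark.sql import functions as F

match_any = "|".join(['A', 'C'])                # e_1 and e_2
regex = rf"((?:{match_any}).*(?:{match_any}))"  # greedy .* spans first to last occurrence

df_in.withColumn("_2", F.concat_ws(",", "_2")) \
     .withColumn("_2", F.regexp_extract("_2", regex, 1)) \
     .withColumn("_2", F.split("_2", ",")) \
     .show(truncate=False)
# _2 becomes [A, B, A, F, C], [C, B, X, A], [A, B, M, C]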

CodePudding user response:

This is an array column. Find the positions of the elements A and C using array_position; that gives you an array of positions. Sort it with sort_array. Then slice the original column, starting at the lowest position, with a length equal to the highest position minus the lowest plus 1. For row 1, for example, the positions of A and C are [1, 5], so the slice starts at position 1 with length 5 - 1 + 1 = 5.

Code below; the intermediate column x is left in so you can follow the steps.

from pyspark.sql.functions import array, array_position, col, slice, sort_array

df_in.withColumn('x', sort_array(array(*[array_position(col('_2'), x) for x in ['A', 'C']]))) \
     .withColumn('y', slice(col('_2'), col('x')[0], col('x')[1] - col('x')[0] + 1)).show()


+---+------------------+------+---------------+
| _1|                _2|     x|              y|
+---+------------------+------+---------------+
|  1|[A, B, A, F, C, D]|[1, 5]|[A, B, A, F, C]|
|  2|[F, C, B, X, A, D]|[2, 5]|   [C, B, X, A]|
|  3|   [L, A, B, M, C]|[2, 5]|   [A, B, M, C]|
+---+------------------+------+---------------+
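
A minimal variant sketch of the same idea, assuming Spark 3.1+ (where slice accepts column arguments, as the one-liner above already relies on) and that both 'A' and 'C' occur in every row, uses array_min/array_max instead of sorting:

from pyspark.sql import functions as F

# positions of 'A' and 'C' in each row (1-based; 0 would mean "not found")
positions = F.array(*[F.array_position(F.col('_2'), x) for x in ['A', 'C']])

df_in.withColumn(
    '_2',
    F.slice(F.col('_2'),
            F.array_min(positions),                               # start at the first occurrence
            F.array_max(positions) - F.array_min(positions) + 1)  # length through the last occurrence
).show()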

CodePudding user response:

  1. Define a function that does what you want with a single element of the dataframe.
  2. Extend it to a "user defined function" (UDF) that is able to operate on a column of the dataframe.

In this example, the function might look like

def trim(start_stop, l):
    """Keep the slice from the first to the last occurrence of any element in start_stop."""
    first = None
    for i, e in enumerate(l):
        if e in start_stop:
            if first is None:
                first = i
            last = i
    return l[first:last + 1] if first is not None else []

trim({'B','C'}, ['A', 'B', 'A', 'F', 'C', 'D'])
# ['B', 'A', 'F', 'C']

For the UDF, we have to decide whether we want start_stop to be a fixed value or flexible.

from pyspark.sql import functions, types

# {'A','C'} as hard-coded value for start_stop
trimUDF = functions.udf(
    lambda x: trim({'A','C'}, x),
    types.ArrayType(types.StringType()))
# use as trimUDF(COLNAME)

# flexible start_stop
def trimUDF(start_stop):
    return functions.udf(
        lambda l: trim(start_stop,l),
        types.ArrayType(types.StringType()))
# use as trimUDF(start_stop)(COLNAME)

For your example and the second version of trimUDF, we obtain:

from pyspark.sql import SparkSession
spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([
    (1,['A', 'B', 'A', 'F', 'C', 'D']),
    (2,['F', 'C', 'B', 'X', 'A', 'D']),
    (3,['L', 'A', 'B', 'M', 'C'])])
df.show()
df = df.withColumn("_2",trimUDF({'A','C'})("_2"))
df.show()

This code results in the following output.

+---+------------------+
| _1|                _2|
+---+------------------+
|  1|[A, B, A, F, C, D]|
|  2|[F, C, B, X, A, D]|
|  3|   [L, A, B, M, C]|
+---+------------------+

+---+---------------+
| _1|             _2|
+---+---------------+
|  1|[A, B, A, F, C]|
|  2|   [C, B, X, A]|
|  3|   [A, B, M, C]|
+---+---------------+

For the record, here is the complete code.

# python3 -m venv venv
# . venv/bin/activate
# pip install wheel pyspark[sql]

from pyspark.sql import SparkSession, functions, types

def trim(start_stop, l):
    first = None 
    for i,e in enumerate(l):
        if e in start_stop:
            if first is None:
                first=i
            last = i
    return l[first:last + 1] if first is not None else []

#trimUDF = functions.udf(
#    lambda x: trim({'A','C'}, x),
#    types.ArrayType(types.StringType()))

def trimUDF(start_stop):
    return functions.udf(
        lambda l: trim(start_stop,l),
        types.ArrayType(types.StringType()))

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([
    (1,['A', 'B', 'A', 'F', 'C', 'D']),
    (2,['F', 'C', 'B', 'X', 'A', 'D']),
    (3,['L', 'A', 'B', 'M', 'C'])])
df.show()
df = df.withColumn("_2",trimUDF({'A','C'})("_2"))
df.show()