I have a PySpark DataFrame where one column is a list with a given order:
df_in = spark.createDataFrame(
    [
        (1, ['A', 'B', 'A', 'F', 'C', 'D']),
        (2, ['F', 'C', 'B', 'X', 'A', 'D']),
        (3, ['L', 'A', 'B', 'M', 'C'])])
I want to specify two elements e_1 and e_2 (e.g. 'A' and 'C'), such that the resulting column contains a list with all the elements that are after the first occurrence of either e_1 or e_2 and before the last occurrence of either e_1 or e_2. Hence, the resulting df should look like this:
df_out = spark.createDataFrame(
    [
        (1, ['A', 'B', 'A', 'F', 'C']),
        (2, ['C', 'B', 'X', 'A']),
        (3, ['A', 'B', 'M', 'C'])])
How do I achieve this? Thanks in advance!
Best regards
CodePudding user response:
One way to do it is:
- Join the array into a string with elements separated by ",".
- Use a regex to extract the required sub-string.
- Split the string by "," back into the original array.
- Cast, if required.
from pyspark.sql import functions as F

df = spark.createDataFrame(data=[(1, ["X A", "X B", "X A", "X F", "X C", "X D"]),
                                 (2, ["X F", "X C", "X B", "X X", "X A", "X D"]),
                                 (3, ["X L", "X A", "X B", "X M", "X C"])],
                           schema=["id", "arr"])

match_list = ["X A", "X C", "X D"]
match_any = "|".join(match_list)
regex = rf"((?:{match_any}).*(?:{match_any}))"

df = df.withColumn("arr", F.concat_ws(",", "arr")) \
       .withColumn("arr", F.regexp_extract("arr", regex, 1)) \
       .withColumn("arr", F.split("arr", ","))
Output:
+---+------------------------------+
|id |arr                           |
+---+------------------------------+
|1  |[X A, X B, X A, X F, X C, X D]|
|2  |[X C, X B, X X, X A, X D]     |
|3  |[X A, X B, X M, X C]          |
+---+------------------------------+
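The same steps applied to the question's df_in (single-letter elements, with 'A' and 'C' as the boundaries) might look like the sketch below. re.escape is an assumption added to guard against regex metacharacters, and note that a row containing neither element would end up as [''] rather than an empty array:

import re
from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.getOrCreate()
df_in = spark.createDataFrame(
    [
        (1, ['A', 'B', 'A', 'F', 'C', 'D']),
        (2, ['F', 'C', 'B', 'X', 'A', 'D']),
        (3, ['L', 'A', 'B', 'M', 'C'])])

match_list = ["A", "C"]
match_any = "|".join(re.escape(w) for w in match_list)
# the greedy .* stretches from the first to the last occurrence of either element
regex = rf"((?:{match_any}).*(?:{match_any}))"

df_out = (df_in
          .withColumn("_2", F.concat_ws(",", "_2"))
          .withColumn("_2", F.regexp_extract("_2", regex, 1))
          .withColumn("_2", F.split("_2", ",")))
df_out.show(truncate=False)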
CodePudding user response:
This is an array.
Find the positions of the elements A and C using array_position. That will give you an array; sort it using sort_array.
Now slice the resultant column, starting from the lowest position from the array_position above. Subtract the lowest position from the maximum and add 1 to get the length to pass into slice.
Code below. The intermediate column x is left in for you to follow through.
from pyspark.sql.functions import array, array_position, col, slice, sort_array

df_in.withColumn('x', sort_array(array(*[array_position(col('_2'), x) for x in ['A', 'C']]))) \
     .withColumn('y', slice(col('_2'), col('x')[0], col('x')[1] - col('x')[0] + 1)) \
     .show()
+---+------------------+------+---------------+
| _1|                _2|     x|              y|
+---+------------------+------+---------------+
|  1|[A, B, A, F, C, D]|[1, 5]|[A, B, A, F, C]|
|  2|[F, C, B, X, A, D]|[2, 5]|   [C, B, X, A]|
|  3|   [L, A, B, M, C]|[2, 5]|   [A, B, M, C]|
+---+------------------+------+---------------+
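An equivalent way to write the slice, sketched with array_min and array_max instead of indexing into the sorted positions array (passing Column arguments to slice requires Spark 3.0 or later). Like the answer above, it assumes both elements occur in every row, because array_position returns 0 when an element is missing:

from pyspark.sql import functions as F

positions = F.array(*[F.array_position(F.col('_2'), x) for x in ['A', 'C']])
df_in.withColumn(
    'y',
    F.slice(
        F.col('_2'),
        F.array_min(positions),                               # start at the first occurrence
        F.array_max(positions) - F.array_min(positions) + 1   # length up to the last occurrence
    )
).show()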
CodePudding user response:
- Define a function that does what you want with a single element of the dataframe.
- Extend it to a "user defined function" (UDF) that is able to operate on a column of the dataframe.
In this example, the function might look like
def trim(start_stop, l):
    first = None
    for i, e in enumerate(l):
        if e in start_stop:
            if first is None:
                first = i
            last = i
    return l[first:last + 1] if first is not None else []
trim({'B','C'}, ['A', 'B', 'A', 'F', 'C', 'D'])
# ['B', 'A', 'F', 'C']
For the UDF, we have to decide whether we want start_stop to be a fixed value or flexible.
from pyspark.sql import functions, types

# {'A','C'} as hard-coded value for start_stop
trimUDF = functions.udf(
    lambda x: trim({'A', 'C'}, x),
    types.ArrayType(types.StringType()))
# use as trimUDF(COLNAME)

# flexible start_stop
def trimUDF(start_stop):
    return functions.udf(
        lambda l: trim(start_stop, l),
        types.ArrayType(types.StringType()))
# use as trimUDF(start_stop)(COLNAME)
For your example and the second version of trimUDF, we obtain:
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([
    (1, ['A', 'B', 'A', 'F', 'C', 'D']),
    (2, ['F', 'C', 'B', 'X', 'A', 'D']),
    (3, ['L', 'A', 'B', 'M', 'C'])])
df.show()
df = df.withColumn("_2", trimUDF({'A', 'C'})("_2"))
df.show()
This code results in the following output.
+---+------------------+
| _1|                _2|
+---+------------------+
|  1|[A, B, A, F, C, D]|
|  2|[F, C, B, X, A, D]|
|  3|   [L, A, B, M, C]|
+---+------------------+

+---+---------------+
| _1|             _2|
+---+---------------+
|  1|[A, B, A, F, C]|
|  2|   [C, B, X, A]|
|  3|   [A, B, M, C]|
+---+---------------+
For the record, here is the complete code.
# python3 -m venv venv
# . venv/bin/activate
# pip install wheel pyspark[sql]
from pyspark.sql import SparkSession, functions, types

def trim(start_stop, l):
    first = None
    for i, e in enumerate(l):
        if e in start_stop:
            if first is None:
                first = i
            last = i
    return l[first:last + 1] if first is not None else []

#trimUDF = functions.udf(
#    lambda x: trim({'A','C'}, x),
#    types.ArrayType(types.StringType()))

def trimUDF(start_stop):
    return functions.udf(
        lambda l: trim(start_stop, l),
        types.ArrayType(types.StringType()))

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([
    (1, ['A', 'B', 'A', 'F', 'C', 'D']),
    (2, ['F', 'C', 'B', 'X', 'A', 'D']),
    (3, ['L', 'A', 'B', 'M', 'C'])])
df.show()
df = df.withColumn("_2", trimUDF({'A', 'C'})("_2"))
df.show()