I have a DataFrame and a list of borders:
test = spark.createDataFrame(
[
(1,),
(2,),
(234,),
(0,),
(6,),
(7,),
(35,),
(46,),
(8,),
],
"Population int",
)
border_list = [0, 1.5, 7, 41, 235]
I want to add two new columns to the DataFrame, "LowBorder" and "UpBorder", holding the nearest border at or below each "Population" value and the nearest border above it.
When I tried to do it with just a Python list and functions, it worked:
lower = lambda x: max([i for i in border_list if x >= i])
upper = lambda x: min([i for i in border_list if x < i])
list_value = [1, 2, 234, 0, 6, 7, 35, 46, 8]
for i in list_value:
    print(lower(i), upper(i))
# Output:
# low high
0 1.5
1.5 7
41 235
0 1.5
1.5 7
7 41
7 41
41 235
7 41
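Because border_list is sorted, the same lookup can be done with bisect from the standard library instead of scanning the whole list for every value. This is a minimal sketch; the helper name borders is hypothetical. Unlike the lambdas above, where max([]) or min([]) raises ValueError for values below 0 or at or above 235, it returns None for the missing side:

```python
from bisect import bisect_right

border_list = [0, 1.5, 7, 41, 235]  # must be sorted ascending


def borders(x, edges=border_list):
    """Hypothetical helper: return (lower, upper), where lower is the largest
    edge <= x and upper is the smallest edge > x; None if no such edge."""
    i = bisect_right(edges, x)  # number of edges <= x
    lower = edges[i - 1] if i > 0 else None
    upper = edges[i] if i < len(edges) else None
    return lower, upper


for value in [1, 2, 234, 0, 6, 7, 35, 46, 8]:
    print(value, *borders(value))
```

bisect_right places equal values to the right, so a value that sits exactly on a border (such as 7) gets that border as its lower bound, matching x >= i in the original lambda.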
But when I tried to convert it to work on a column, it didn't:
from pyspark.sql import functions as F
from pyspark.sql.types import FloatType
lower_border = F.udf(lambda x: max([i for i in border_list if x >= i]), FloatType())
upper_border = F.udf(lambda x: min([i for i in border_list if x < i]), FloatType())
test.withColumn("LowBorder", lower_border("Population")) \
.withColumn("UpBorder", upper_border("Population"))
display(test) # no changes in test Dataframe
Adding the columns through select doesn't work as desired either:
display(test.select(lower_border("Population").alias('low'), upper_border("Population").alias('high')))
# Output:
low high
-----------
null 1.5
1.5 null
null null
null 1.5
1.5 null
null null
null null
null null
null null
The expected output for the test DataFrame is:
Population | LowBorder | UpBorder
---------------------------------
1 0 1.5
2 1.5 7
234 41 235
0 0 1.5
6 1.5 7
7 7 41
35 7 41
46 41 235
8 7 41
Answer:
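You can create an array out of your border_list, then filter it and take either the maximum of the borders at or below Population or the minimum of the borders above it. Because the array is sorted, these are the last and the first element of the filtered array, which element_at retrieves with indices -1 and 1. Note that F.filter with a Python lambda requires Spark 3.1 or later.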
from pyspark.sql import functions as F
test = spark.createDataFrame([(1,), (2,), (234,), (0,), (6,), (7,), (35,), (46,), (8,)], "Population int")
border_list = [0, 1.5, 7, 41, 235]
arr = F.array_sort(F.array([F.lit(x) for x in border_list]))
test = test.select(
'Population',
F.element_at(F.filter(arr, lambda x: x <= F.col('Population')), -1).alias('LowBorder'),
F.element_at(F.filter(arr, lambda x: x > F.col('Population')), 1).alias('UpBorder'),
)
test.show(truncate=0)
# +----------+---------+--------+
# |Population|LowBorder|UpBorder|
# +----------+---------+--------+
# |1         |0.0      |1.5     |
# |2         |1.5      |7.0     |
# |234       |41.0     |235.0   |
# |0         |0.0      |1.5     |
# |6         |1.5      |7.0     |
# |7         |7.0      |41.0    |
# |35        |7.0      |41.0    |
# |46        |41.0     |235.0   |
# |8         |7.0      |41.0    |
# +----------+---------+--------+