Pyspark get value closest to a fixed parameter in a given column within a window function

Time:11-29

I've been bashing my head against this problem for a while now and would appreciate any insights!

uid date-perf other-data control-val
1 2022-10-10 500 100
1 2022-11-10 550 180
2 2022-10-10 400 180
2 2022-11-10 580 210

For the above data set, I want to find which control-val is closest to a given value (200) within each uid, and then save that control-val to a new column, like so:

uid date-perf other-data control-val closest-to-200
1 2022-10-10 500 100 180
1 2022-11-10 550 180 180
2 2022-10-10 400 180 210
2 2022-11-10 580 210 210

I've tried using a window function and then a UDF that takes the absolute difference to find the closest value. In the function below I pass in the column I want to check, in this case control-val:

def closest(cols):
    W = Window.partitionBy("uuid").orderBy("date-perf")
    return F.array_sort(F.transform(
        F.collect_list(F.struct("other-data")).over(W),
        lambda x: F.struct(
            F.abs(200 - F.col(cols)).alias("diff"),
            x["other-data"].alias("other-data"),
        )
    ))[0]["other-data"].alias("closest-to-200")

I then call it like this in PySpark:

df = df.select("*",closest(F.col("control-val")))

But I get the following error: TypeError: Column is not iterable

Any ideas on how to achieve this?

CodePudding user response:

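Your function references other-data, yet your question doesn't mention it as being important, so I ignored it for this answer and focused on returning the value of control-val that has the smallest absolute difference to the reference value of 200. As for the TypeError, it most likely comes from passing a Column object (F.col("control-val")) into closest(), which then hands it to F.col() again; F.col() expects a column name as a string.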

from pyspark.sql import functions as F, Window

df = spark.createDataFrame(
    [
        (1, "2022-10-10", 500, 100),
        (1, "2022-11-10", 550, 180),
        (2, "2022-10-10", 400, 180),
        (2, "2022-11-10", 580, 210),
    ],
    ["uid", "date-perf", "other-data", "control-val"],
)

diff = F.abs(F.col('control-val') - F.lit(200))
w = Window.partitionBy('uid').orderBy(diff)
df.withColumn("closest-to-200", F.first('control-val').over(w)).show()
+---+----------+----------+-----------+--------------+
|uid| date-perf|other-data|control-val|closest-to-200|
+---+----------+----------+-----------+--------------+
|  1|2022-11-10|       550|        180|           180|
|  1|2022-10-10|       500|        100|           180|
|  2|2022-11-10|       580|        210|           210|
|  2|2022-10-10|       400|        180|           210|
+---+----------+----------+-----------+--------------+
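
If you want to sanity-check the result on a small sample without a Spark session, the same per-uid logic can be sketched in plain Python. This is only an illustration of what the window computes, not a replacement for it; the function name closest_per_group and the (uid, control_val) tuple layout are made up for this example.

```python
from collections import defaultdict


def closest_per_group(rows, target=200):
    """Return {uid: control_val closest to target}.

    rows is an iterable of (uid, control_val) pairs (hypothetical layout).
    """
    groups = defaultdict(list)
    for uid, control_val in rows:
        groups[uid].append(control_val)
    # min() with a key picks the value with the smallest absolute difference;
    # on a tie it keeps the first value it encountered.
    return {
        uid: min(vals, key=lambda v: abs(v - target))
        for uid, vals in groups.items()
    }


rows = [(1, 100), (1, 180), (2, 180), (2, 210)]
print(closest_per_group(rows))  # {1: 180, 2: 210}
```

One thing to keep in mind: in the Spark version the window is ordered only by the difference, so if two values in the same uid are equally close to 200, which one F.first returns is not guaranteed. Adding a second ordering column (for example control-val itself) would make that choice deterministic.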