check if values are within intervals in pyspark

Time:12-15

I have a large DataFrame A with values like this:

df_a = spark.createDataFrame([
    (0, 23), (1, 6), (2, 55), (3, 1), (4, 12), (5, 51),
], ("id", "x"))
# +---+---+
# | id|  x|
# +---+---+
# |  0| 23|
# |  1|  6|
# |  2| 55|
# |  3|  1|
# |  4| 12|
# |  5| 51|
# +---+---+

I also have a DataFrame B with sorted, non-overlapping closed intervals like this:

df_b = spark.createDataFrame([
    (0, 1, 5), (1, 8, 10), (2, 15, 16), (3, 20, 30), (4, 50, 52),
], ("id", "start", "end"))
# +---+-----+---+
# | id|start|end|
# +---+-----+---+
# |  0|    1|  5|
# |  1|    8| 10|
# |  2|   15| 16|
# |  3|   20| 30|
# |  4|   50| 52|
# +---+-----+---+

I want to check whether each value of DataFrame A falls inside one of the intervals of DataFrame B and, if so, store that interval's id in a new column (interval_id). My output DataFrame should look like this:

id   x    interval_id
0    23   3
1    6    null
2    55   null
3    1    0
4    12   null
5    51   4

Is there a way to do this efficiently without UDFs?

CodePudding user response:

A simple left join with a range condition should do the job:

from pyspark.sql import functions as F

result = df_a.join(
    df_b.withColumnRenamed("id", "interval_id"),
    F.col("x").between(F.col("start"), F.col("end")),
    "left"
).drop("start", "end")

result.show()

#+---+---+-----------+
#| id|  x|interval_id|
#+---+---+-----------+
#|  0| 23|          3|
#|  1|  6|       null|
#|  2| 55|       null|
#|  3|  1|          0|
#|  4| 12|       null|
#|  5| 51|          4|
#+---+---+-----------+

CodePudding user response:

You can left-join df_a to df_b on the condition that df_a["x"] lies between df_b["start"] and df_b["end"] (both bounds inclusive), then select the columns you need.


df_a = spark.createDataFrame([
    (0, 23), (1, 6), (2, 55), (3, 1), (4, 12), (5, 51),
], ("id", "x"))

df_b = spark.createDataFrame([
    (0, 1, 5), (1, 8, 10), (2, 15, 16), (3, 20, 30), (4, 50, 52),
], ("id", "start", "end"))


df_a.join(df_b, df_a["x"].between(df_b["start"], df_b["end"]), how="left")\
    .select(df_a["id"], df_a["x"], df_b["id"].alias("interval_id")).show()

Output

+---+---+-----------+
| id|  x|interval_id|
+---+---+-----------+
|  0| 23|          3|
|  1|  6|       null|
|  2| 55|       null|
|  3|  1|          0|
|  4| 12|       null|
|  5| 51|          4|
+---+---+-----------+
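
To sanity-check the join output on a small sample outside Spark, the same closed-interval lookup can be sketched in plain Python. This is only a reference check, not a replacement for the join. It relies on the question's statement that the intervals are sorted and non-overlapping, and `find_interval_id` is a hypothetical helper name.

```python
from bisect import bisect_right

# Same sample data as above: (id, start, end), sorted by start and non-overlapping.
intervals = [(0, 1, 5), (1, 8, 10), (2, 15, 16), (3, 20, 30), (4, 50, 52)]
starts = [start for _, start, _ in intervals]

def find_interval_id(x):
    """Return the id of the closed interval containing x, or None if there is none."""
    # Index of the last interval whose start is <= x.
    pos = bisect_right(starts, x) - 1
    if pos < 0:
        return None
    interval_id, _, end = intervals[pos]
    # Closed interval: the end bound counts as inside.
    return interval_id if x <= end else None

values = [(0, 23), (1, 6), (2, 55), (3, 1), (4, 12), (5, 51)]
expected = [(i, x, find_interval_id(x)) for i, x in values]
```

Comparing `expected` with the collected rows of the joined DataFrame confirms that both interval bounds are treated as inclusive.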