I have the following dataframe:
d = [{'id': 3, 'ratio': 1.3, 'vol1': 100},
     {'id': 5, 'ratio': 0.3, 'vol1': 200},
     {'id': 1, 'ratio': 1.1, 'vol1': 300},
     {'id': 8, 'ratio': 0.8, 'vol1': 400},
     {'id': 2, 'ratio': 2.0, 'vol1': 500},
     {'id': 4, 'ratio': 0.0, 'vol1': 600}]
data = spark.createDataFrame(d)
I need to add a column new_col_cond whose value depends on several external lists/arrays (I have also tried dictionaries), for example:
import numpy as np

q1 = [10, 20, 30, 40, 50, 60, 70, 80, 90]
q1_n = np.array(q1).reshape(-1)  # numpy array built from the list above
q2 = [1, 2, 3, 4, 5, 6, 7, 8, 9]
q2_n = np.array(q2).reshape(-1)
The new column should look at the value of ratio to pick one of the arrays, then use id as the index into it. I have tried:
from pyspark.sql.functions import col, when

data = data.withColumn('new_col_cond',
                       when(col('ratio') < 1, q1[col('id')])
                       .when(col('ratio') > 1, q2[col('id')]))  # also tried with the numpy arrays
but this raises errors. I assume the main source of the error is using a column as an index into a Python list/array, but I am not sure how else to get the index into the array. Given the conditional nature of the column I have not tried a join (the data is millions of rows and the lists are in the thousands).
Due to the size of the dataset I am also steering away from Pandas and UDFs. The resulting dataframe should look like this:
+---+-----+----+------------+
| id|ratio|vol1|new_col_cond|
+---+-----+----+------------+
|  3|  1.3| 100|           4|
|  5|  0.3| 200|          60|
|  1|  1.1| 300|           2|
|  8|  0.8| 400|          90|
|  2|  2.0| 500|           3|
|  4|  0.0| 600|          50|
+---+-----+----+------------+
Any help in solving this issue is appreciated.
CodePudding user response:
It would be easier to add 'new_col_cond' to your dictionaries before creating your dataframe.
d = [{'id': 3, 'ratio': 1.3, 'vol1': 100},
     {'id': 5, 'ratio': 0.3, 'vol1': 200},
     {'id': 1, 'ratio': 1.1, 'vol1': 300},
     {'id': 8, 'ratio': 0.8, 'vol1': 400},
     {'id': 2, 'ratio': 2.0, 'vol1': 500},
     {'id': 4, 'ratio': 0.0, 'vol1': 600}]

q1 = [10, 20, 30, 40, 50, 60, 70, 80, 90]
q2 = [1, 2, 3, 4, 5, 6, 7, 8, 9]

# do the lookup in plain Python before the data ever reaches Spark
for d_ in d:
    d_['new_col_cond'] = q1[d_['id']] if d_['ratio'] < 1 else q2[d_['id']]

df = spark.createDataFrame(d)
Note: while this works for the data you've shown, it isn't fully robust; if an 'id' value is greater than 8, the list lookup will raise an IndexError.
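A minimal way to guard against that, assuming out-of-range ids should simply map to None rather than fail, is to bounds-check before indexing:

# hypothetical guard: ids outside the list range get None instead of raising
for d_ in d:
    i = d_['id']
    src = q1 if d_['ratio'] < 1 else q2
    d_['new_col_cond'] = src[i] if 0 <= i < len(src) else None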
CodePudding user response:
Create ArrayType column expressions from the numpy arrays and use them in your condition like this:
from pyspark.sql import functions as F

# turn each numpy array into a literal ArrayType column, then index it with the id column
q1_n = F.array(*[F.lit(int(x)) for x in q1_n])
q2_n = F.array(*[F.lit(int(x)) for x in q2_n])

result = data.withColumn(
    'new_col_cond',
    F.when(F.col('ratio') < 1, q1_n[F.col('id')])
     .when(F.col('ratio') > 1, q2_n[F.col('id')])
)
result.show()
# +---+-----+----+------------+
# | id|ratio|vol1|new_col_cond|
# +---+-----+----+------------+
# |  3|  1.3| 100|           4|
# |  5|  0.3| 200|          60|
# |  1|  1.1| 300|           2|
# |  8|  0.8| 400|          90|
# |  2|  2.0| 500|           3|
# |  4|  0.0| 600|          50|
# +---+-----+----+------------+
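Two caveats worth noting. First, rows where ratio is exactly 1 match neither when branch, so new_col_cond comes out null; chain an .otherwise(...) if you want a fallback value there. Second, since you mention the lists are in the thousands, embedding thousands of literals in the plan via F.array can get unwieldy; a broadcast join against a small lookup DataFrame is a common alternative. A rough sketch, assuming ids line up with list positions (the lookup/q1_val/q2_val names are illustrative, not from the original post):

from pyspark.sql import functions as F

# one small lookup row per index, carrying both candidate values
lookup = spark.createDataFrame(
    [(i, a, b) for i, (a, b) in enumerate(zip(q1, q2))],
    ['id', 'q1_val', 'q2_val']
)

result = (
    data.join(F.broadcast(lookup), on='id', how='left')  # broadcast: lookup is tiny
        .withColumn('new_col_cond',
                    F.when(F.col('ratio') < 1, F.col('q1_val'))
                     .when(F.col('ratio') > 1, F.col('q2_val')))
        .drop('q1_val', 'q2_val')
)

The broadcast hint ships the small lookup table to every executor, so the millions-of-rows side is never shuffled.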