Adding None to PySpark array


I want to create an array column that is conditionally populated based on an existing column, and sometimes I want it to contain None. Here's some example code:

from pyspark.sql import Row
from pyspark.sql import SparkSession
from pyspark.sql.functions import when, array, lit
 
spark = SparkSession.builder.getOrCreate()
 
df = spark.createDataFrame([
    Row(ID=1),
    Row(ID=2),
    Row(ID=2),
    Row(ID=1)
])

value_lit = 0.45
size = 10

df = df.withColumn("TEST",when(df["ID"] == 2,array([None for i in range(size)])).otherwise(array([lit(value_lit) for i in range(size)])))

df.show(truncate=False)

And here's the error I'm getting:

TypeError: Invalid argument, not a string or column: None of type <type 'NoneType'>. For column literals, use 'lit', 'array', 'struct' or 'create_map' function.

I know it isn't a string or column, but I don't see why it has to be.

  • lit: doesn't work.
  • array: I'm not sure how to use array in this context.
  • struct: probably the way to go but I'm not sure how to use it here. Perhaps I have to set an option to allow the new column to contain None values?
  • create_map: I'm not creating a key:value map so I'm sure this is not the correct one to use.

CodePudding user response:

Try this; it works for me (wrap the None in lit before passing it to array):

from pyspark.sql import Row
from pyspark.sql import SparkSession
from pyspark.sql.functions import when, array, lit

spark = SparkSession.builder.getOrCreate()

df = spark.createDataFrame([
    Row(ID=1),
    Row(ID=2),
    Row(ID=2),
    Row(ID=1)
])

value_lit = 0.45
size = 10

df = df.withColumn("TEST",when(df["ID"] == 2,array([lit(None) for i in range(size)])).otherwise(array([lit(value_lit) for i in range(size)])))

df.show(truncate=False)

Output:

(The original answer showed a screenshot of df.show() here: rows with ID 1 get an array of ten 0.45 values, rows with ID 2 get an array of ten nulls.)
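Why this works: array() only accepts Column objects or column-name strings, which is exactly what the TypeError above complains about, so a bare Python None is rejected, while lit(None) wraps it into a null Column literal that array() can take. If you'd rather be explicit about the element type instead of letting Spark reconcile the two when/otherwise branches, you can also cast the null literal; a minimal sketch, assuming the same df, value_lit and size as above:

from pyspark.sql.functions import when, array, lit

# Casting the null literal pins the element type, so TEST resolves to
# array<double> in both branches rather than relying on implicit coercion.
df = df.withColumn(
    "TEST",
    when(df["ID"] == 2, array([lit(None).cast("double") for _ in range(size)]))
    .otherwise(array([lit(value_lit) for _ in range(size)]))
)

df.show(truncate=False)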

CodePudding user response:

The condition must be flipped: F.when(F.col('ID') != 2, value_lit)

If you do that, you don't need otherwise at all: when the when condition is not satisfied, the result is null by default.

Also, just one list comprehension is enough.

from pyspark.sql import SparkSession, functions as F
spark = SparkSession.builder.getOrCreate()
 
df = spark.createDataFrame([(1,), (2,), (2,), (1,)], ['ID'])

value_lit = 0.45
size = 10

df = df.withColumn("TEST", F.array([F.when(F.col('ID') != 2, value_lit) for i in range(size)]))

df.show(truncate=False)
# +---+------------------------------------------------------------+
# |ID |TEST                                                        |
# +---+------------------------------------------------------------+
# |1  |[0.45, 0.45, 0.45, 0.45, 0.45, 0.45, 0.45, 0.45, 0.45, 0.45]|
# |2  |[,,,,,,,,,]                                                 |
# |2  |[,,,,,,,,,]                                                 |
# |1  |[0.45, 0.45, 0.45, 0.45, 0.45, 0.45, 0.45, 0.45, 0.45, 0.45]|
# +---+------------------------------------------------------------+

I've run this code on Spark 2.4.3.
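For Spark 2.4+, array_repeat offers a possible further simplification; a minimal sketch of the same idea, assuming it behaves as described:

from pyspark.sql import SparkSession, functions as F
spark = SparkSession.builder.getOrCreate()

df = spark.createDataFrame([(1,), (2,), (2,), (1,)], ['ID'])

value_lit = 0.45
size = 10

# array_repeat repeats the element `size` times; the element is 0.45 when
# ID != 2 and null otherwise, because `when` without `otherwise` yields null.
df = df.withColumn("TEST", F.array_repeat(F.when(F.col('ID') != 2, value_lit), size))

df.show(truncate=False)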
