Replacing unique array of strings in a row using pyspark

Time:12-13

I am trying the following code, which should replace the empty list in the column "apples_set" with the unique values collected from that whole column when the condition "all" is satisfied (that is, when "apples_logic_string" equals "all").

The column "apples_set" is of type Array[String]; "apples_logic_string" is a plain string.

The DataFrame looks like this:

apples_patterns.show()

+--------------------+-----------------+
| apples_logic_string|apples_set       |
+--------------------+-----------------+
|  "234"             |["43","54"]      |
|  "65"              |["95"]           |
|  "all"             |[]               |
|  "76"              |["84","67"]      |
+--------------------+-----------------+

The code:

import pyspark.sql.functions as F

unique_all_apples = set(apples_patterns.agg(F.flatten(F.collect_set('apples_set'))).head()[0])  # noqa
error_patterns = apples_patterns.withColumn('apples_set', F.when(F.col('apples_logic_string') == 'all',
                                            unique_all_apples).otherwise(F.col('apples_set')))

The Error:

Traceback (most recent call last):
  File "/myproject/datasets/apples_matching.py", line 24, in compute
    apples_patterns = apples_patterns.withColumn('apples_set', F.when(F.col('apples_logic_string') == 'all',
  File "/scratch/asset-install/1c9821b4f6adc95ac4d5f15ff109001b/miniconda38/lib/python3.8/site-packages/pyspark/sql/functions.py", line 1518, in when
    jc = sc._jvm.functions.when(condition._jc, v)
  File "/scratch/asset-install/1c9821b4f6adc95ac4d5f15ff109001b/miniconda38/lib/python3.8/site-packages/py4j/java_gateway.py", line 1321, in __call__
    return_value = get_return_value(
  File "/scratch/asset-install/1c9821b4f6adc95ac4d5f15ff109001b/miniconda38/lib/python3.8/site-packages/pyspark/sql/utils.py", line 111, in deco
    return f(*a, **kw)
  File "/scratch/asset-install/1c9821b4f6adc95ac4d5f15ff109001b/miniconda38/lib/python3.8/site-packages/py4j/protocol.py", line 326, in get_return_value
    raise Py4JJavaError(
py4j.protocol.Py4JJavaError: An error occurred while calling z:org.apache.spark.sql.functions.when.
: java.lang.RuntimeException: Unsupported literal type class java.util.ArrayList [43,54,95,84,67]

Answer:

The easiest solution is to create a second DataFrame with a single row that contains all distinct values of apples_set (explode the arrays, then collect_set), and then join it to the original DataFrame. In Scala:

import spark.implicits._
import org.apache.spark.sql.functions._

val data = Seq(
  ("234", Seq("43", "54")),
  ("65", Seq("95")),
  ("all", Seq.empty[String]),
  ("76", Seq("84", "67"))
)

val df = spark.sparkContext.parallelize(data).toDF("apples_logic_string", "apples_set")

// One row holding every distinct apple, keyed by "all" so the join only matches that row
val allDf = df.select(explode(col("apples_set")).as("apples_set")).agg(collect_set("apples_set").as("all_apples_set"))
  .withColumn("apples_logic_string", lit("all"))

// Left join keeps every row; only the "all" row picks up all_apples_set
df.join(broadcast(allDf), Seq("apples_logic_string"), "left")
  .withColumn("apples_set", when(col("apples_logic_string").equalTo("all"), col("all_apples_set")).otherwise(col("apples_set")))
  .drop("all_apples_set").show(false)

+-------------------+--------------------+
|apples_logic_string|apples_set          |
+-------------------+--------------------+
|234                |[43, 54]            |
|65                 |[95]                |
|all                |[84, 95, 67, 54, 43]|
|76                 |[84, 67]            |
+-------------------+--------------------+

Answer:

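You can use the array function (see the pyspark.sql.functions.array documentation) to turn the Python values into an array column built from literals, which F.when accepts.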

In your case you may use it like this:

F.array([F.lit(x) for x in unique_all_apples])

Sample code:

import pyspark.sql.functions as F

x = [("234", ["43", "54"]), ("65", ["95"]), ("all", []), ("76", ["84", "67"])]
apples_patterns = spark.createDataFrame(x, schema=["apples_logic_string", "apples_set"])

unique_all_apples = set(
    apples_patterns.agg(F.flatten(F.collect_set("apples_set"))).head()[0]
)
error_patterns = apples_patterns.withColumn(
    "apples_set",
    F.when(
        F.col("apples_logic_string") == "all",
        F.array([F.lit(x) for x in unique_all_apples]),
    ).otherwise(F.col("apples_set")),
)

And the output:

+-------------------+--------------------+
|apples_logic_string|          apples_set|
+-------------------+--------------------+
|                234|            [43, 54]|
|                 65|                [95]|
|                all|[54, 95, 43, 67, 84]|
|                 76|            [84, 67]|
+-------------------+--------------------+