I am trying the following code, which should replace an empty list with the array of unique values from a column ("apples_set") in rows where "apples_logic_string" is "all".
The column "apples_set" is of type Array[String].
Data frame looks like this:
apples_patterns.show()
+--------------------+-----------------+
| apples_logic_string|apples_set       |
+--------------------+-----------------+
| "234"              |["43","54"]      |
| "65"               |["95"]           |
| "all"              |[]               |
| "76"               |["84","67"]      |
+--------------------+-----------------+
The code:
unique_all_apples = set(apples_patterns.agg(F.flatten(F.collect_set('apples_set'))).head()[0]) # noqa
error_patterns = apples_patterns.withColumn('apples_set', F.when(F.col('apples_logic_string') == 'all',
    unique_all_apples).otherwise(F.col('apples_set')))
The Error:
Traceback (most recent call last):
  File "/myproject/datasets/apples_matching.py", line 24, in compute
    apples_patterns = apples_patterns.withColumn('apples_set', F.when(F.col('apples_logic_string') == 'all',
  File "/scratch/asset-install/1c9821b4f6adc95ac4d5f15ff109001b/miniconda38/lib/python3.8/site-packages/pyspark/sql/functions.py", line 1518, in when
    jc = sc._jvm.functions.when(condition._jc, v)
  File "/scratch/asset-install/1c9821b4f6adc95ac4d5f15ff109001b/miniconda38/lib/python3.8/site-packages/py4j/java_gateway.py", line 1321, in __call__
    return_value = get_return_value(
  File "/scratch/asset-install/1c9821b4f6adc95ac4d5f15ff109001b/miniconda38/lib/python3.8/site-packages/pyspark/sql/utils.py", line 111, in deco
    return f(*a, **kw)
  File "/scratch/asset-install/1c9821b4f6adc95ac4d5f15ff109001b/miniconda38/lib/python3.8/site-packages/py4j/protocol.py", line 326, in get_return_value
    raise Py4JJavaError(
py4j.protocol.Py4JJavaError: An error occurred while calling z:org.apache.spark.sql.functions.when.
: java.lang.RuntimeException: Unsupported literal type class java.util.ArrayList [43,54,95,84,67]
Answer:
The easiest solution is to create a second DataFrame with a single row that contains all distinct values of apples_set, built with explode followed by collect_set, and then join it to the original DataFrame (shown here in Scala):
import org.apache.spark.sql.functions._
import spark.implicits._
val data = Seq(
  ("234", Seq("43", "54")),
  ("65", Seq("95")),
  ("all", Seq.empty[String]),
  ("76", Seq("84", "67"))
)
val df = spark.sparkContext.parallelize(data).toDF("apples_logic_string", "apples_set")
// One-row DataFrame holding every distinct apple, keyed by "all" so that it joins only to that row
val allDf = df.select(explode(col("apples_set")).as("apples_set"))
  .agg(collect_set("apples_set").as("all_apples_set"))
  .withColumn("apples_logic_string", lit("all"))
// Left join keeps every original row; only the "all" row takes the collected set
df.join(broadcast(allDf), Seq("apples_logic_string"), "left")
  .withColumn("apples_set", when(col("apples_logic_string").equalTo("all"), col("all_apples_set")).otherwise(col("apples_set")))
  .drop("all_apples_set").show(false)
+-------------------+--------------------+
|apples_logic_string|apples_set          |
+-------------------+--------------------+
|234                |[43, 54]            |
|65                 |[95]                |
|all                |[84, 95, 67, 54, 43]|
|76                 |[84, 67]            |
+-------------------+--------------------+
Answer:
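You can use the array function (see the pyspark.sql.functions.array documentation). F.when treats a value that is not a Column as a literal, and Spark cannot build a literal from a whole collection, which is what the "Unsupported literal type" error reports.
In your case, build an array column from individual literals instead: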
F.array([F.lit(x) for x in unique_all_apples])
Sample code:
import pyspark.sql.functions as F
x = [("234", ["43", "54"]), ("65", ["95"]), ("all", []), ("76", ["84", "67"])]
apples_patterns = spark.createDataFrame(x, schema=["apples_logic_string", "apples_set"])
# Collect the distinct apples on the driver
unique_all_apples = set(
    apples_patterns.agg(F.flatten(F.collect_set("apples_set"))).head()[0]
)
error_patterns = apples_patterns.withColumn(
    "apples_set",
    F.when(
        F.col("apples_logic_string") == "all",
        # Wrap each value in F.lit and combine them into one array column
        F.array([F.lit(x) for x in unique_all_apples]),
    ).otherwise(F.col("apples_set")),
)
error_patterns.show()
And the output:
+-------------------+--------------------+
|apples_logic_string|          apples_set|
+-------------------+--------------------+
|                234|            [43, 54]|
|                 65|                [95]|
|                all|[54, 95, 43, 67, 84]|
|                 76|            [84, 67]|
+-------------------+--------------------+