The Scenario
I have a dataframe containing the following data:
import pandas as pd
from pyspark.sql.types import ArrayType, StringType, IntegerType, FloatType, StructType, StructField
import pyspark.sql.functions as F
a = [1,2,3]
b = [['a', 'b', 'c'], ['d', 'e', 'f'], ['g', 'h', 'i']]
df = pd.DataFrame({
    'id': a,
    'list1': b,
})
df = spark.createDataFrame(df)
df.printSchema()
df.show()
+---+---------+
| id|    list1|
+---+---------+
|  1|[a, b, c]|
|  2|[d, e, f]|
|  3|[g, h, i]|
+---+---------+
I also have a static list containing the following values:
list2 = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i']
What I want to do
I want to compare each value of list2 to each value of list1 in my data, and build an array of 0/1 values, with 1 indicating that the value of list2 is present in list1 and 0 indicating that it is not.
The resulting output should look like this:
+---+-----------+-----------------------------+
| id| list1     | result                      |
+---+-----------+-----------------------------+
| 1 | [a, b, c] | [1, 1, 1, 0, 0, 0, 0, 0, 0] |
| 2 | [d, e, f] | [0, 0, 0, 1, 1, 1, 0, 0, 0] |
| 3 | [g, h, i] | [0, 0, 0, 0, 0, 0, 1, 1, 1] |
+---+-----------+-----------------------------+
I need the results in this format because I am eventually going to be multiplying the result arrays by a scaling factor.
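For context, the eventual scaling step would look roughly like this (just a sketch, assuming a result column like the one above already exists and reusing the F import from earlier; the 2.5 factor and the scaled_result name are only placeholders):
# Sketch only: multiply every element of the result array by a placeholder
# scaling factor of 2.5 (the real factor comes from elsewhere)
scaled = df.withColumn(
    "scaled_result",
    F.transform("result", lambda v: v * F.lit(2.5))
)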
My attempt
# Insert list2 into the dataframe as a literal array column
df = df.withColumn("list2", F.array([F.lit(x) for x in list2]))
# Build the result arrays by checking each element of list2 against list1
differencer = F.udf(lambda list1, list2: [1 if x in list1 else 0 for x in list2], ArrayType(IntegerType()))
df = df.withColumn('result', differencer('list1', 'list2'))
df.show()
However, I get the following error:
An error was encountered:
An error occurred while calling o151.showString.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 11.0 failed 4 times, most recent failure: Lost task 0.3 in stage 11.0 (TID 287) (ip-10-0-0-142.ec2.internal executor 8): java.lang.RuntimeException: Failed to run command: /usr/bin/virtualenv -p python3 --system-site-packages virtualenv_application_1665327460183_0007_0
at org.apache.spark.api.python.VirtualEnvFactory.execCommand(VirtualEnvFactory.scala:120)
at org.apache.spark.api.python.VirtualEnvFactory.setupVirtualEnv(VirtualEnvFactory.scala:78)
at org.apache.spark.api.python.PythonWorkerFactory.<init>(PythonWorkerFactory.scala:94)
at org.apache.spark.SparkEnv.$anonfun$createPythonWorker$1(SparkEnv.scala:125)
at scala.collection.mutable.HashMap.getOrElseUpdate(HashMap.scala:86)
at org.apache.spark.SparkEnv.createPythonWorker(SparkEnv.scala:125)
at org.apache.spark.api.python.BasePythonRunner.compute(PythonRunner.scala:162)
at org.apache.spark.sql.execution.python.BatchEvalPythonExec.evaluate(BatchEvalPythonExec.scala:81)
at org.apache.spark.sql.execution.python.EvalPythonExec.$anonfun$doExecute$2(EvalPythonExec.scala:130)
at org.apache.spark.rdd.RDD.$anonfun$mapPartitions$2(RDD.scala:863)
at org.apache.spark.rdd.RDD.$anonfun$mapPartitions$2$adapted(RDD.scala:863)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:373)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:337)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:373)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:337)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:373)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:337)
at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
at org.apache.spark.scheduler.Task.run(Task.scala:133)
at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:506)
at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1474)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:509)
at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
at java.lang.Thread.run(Thread.java:750)
I've tried dozens of iterations and approaches, but everything I try results in the error above.
How can I get this to work? Ideally without having to insert list2 into the dataframe prior to running the comparison.
Thanks
CodePudding user response:
The idea is to add list2 as an extra column to the dataframe and then use transform to check, for each element of the newly added column, whether it is part of the array in column list1.
from pyspark.sql import functions as F

# Add list2 as a literal array column, then replace each element with a 0/1
# flag indicating whether it appears in list1
df.withColumn("result", F.array(*map(F.lit, list2))) \
    .withColumn("result", F.transform("result", lambda v: F.array_contains(F.col("list1"), v).cast("int"))) \
    .show(truncate=False)
Output:
+---+---------+---------------------------+
|id |list1    |result                     |
+---+---------+---------------------------+
|1  |[a, b, c]|[1, 1, 1, 0, 0, 0, 0, 0, 0]|
|2  |[d, e, f]|[0, 0, 0, 1, 1, 1, 0, 0, 0]|
|3  |[g, h, i]|[0, 0, 0, 0, 0, 0, 1, 1, 1]|
+---+---------+---------------------------+
Using the built-in function transform (which accepts a Python lambda since Spark 3.1) is faster than a UDF, as it avoids the serialization overhead that comes with UDFs.
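If you prefer not to materialize list2 as a column at all (as the question asks), the same result can be built in a single expression by generating one array_contains check per element of list2. A sketch of that variant:
# Sketch: build the 0/1 array directly from the static list2, without adding
# list2 to the dataframe first
df.withColumn(
    "result",
    F.array(*[F.array_contains("list1", x).cast("int") for x in list2])
).show(truncate=False)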