Pyspark Mean value of each element in multiple lists

Time:03-04

I have a df with 2 columns:

  • id
  • vector

This is a sample of how it looks:

+--------------------+----------+
|              vector|        id|
+--------------------+----------+
|[8.32,3.22,5.34,6.5]|1046091128|
|[8.52,3.34,5.31,6.3]|1046091128|
|[8.44,3.62,5.54,6.4]|1046091128|
|[8.31,3.12,5.21,6.1]|1046091128|
+--------------------+----------+

I want to groupBy id and take the mean of each element of the vectors. For example, the first value in the aggregated list will be (8.32 + 8.52 + 8.44 + 8.31) / 4, and so on.

Any help is appreciated.
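To make the goal concrete, here is the position-wise mean sketched in plain Python, using the sample rows from the table above:

```python
# Rows of the "vector" column for a single id, copied from the sample above.
vectors = [
    [8.32, 3.22, 5.34, 6.5],
    [8.52, 3.34, 5.31, 6.3],
    [8.44, 3.62, 5.54, 6.4],
    [8.31, 3.12, 5.21, 6.1],
]

# zip(*vectors) groups the i-th element of every row together,
# so each "column" of the vectors can be averaged independently.
means = [sum(col) / len(col) for col in zip(*vectors)]
print(means)  # the first entry is (8.32 + 8.52 + 8.44 + 8.31) / 4
```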

CodePudding user response:

You can use the posexplode function and then aggregate the exploded values by their average. Something like below -

from pyspark.sql.functions import posexplode, avg
from pyspark.sql.types import StructType, StructField, ArrayType, FloatType, IntegerType

data = [([8.32, 3.22, 5.34, 6.5], 1046091128),
        ([8.52, 3.34, 5.31, 6.3], 1046091128),
        ([8.44, 3.62, 5.54, 6.4], 1046091128),
        ([8.31, 3.12, 5.21, 6.1], 1046091128)]
schema = StructType([StructField("vector", ArrayType(FloatType())),
                     StructField("id", IntegerType())])

df = spark.createDataFrame(data=data, schema=schema)

# posexplode emits one row per array element, with its position ("pos")
# and value ("col"); pivoting on "pos" turns each position into a column.
df.select("id", posexplode("vector")).groupBy("id").pivot("pos").agg(avg("col")).show()

The output would look something like:

+----------+-----------------+------------------+-----------------+-----------------+
|        id|                0|                 1|                2|                3|
+----------+-----------------+------------------+-----------------+-----------------+
|1046091128|8.397500038146973|3.3249999284744263|5.350000023841858|6.325000047683716|
+----------+-----------------+------------------+-----------------+-----------------+

You can rename the columns later if required.

CodePudding user response:

This assumes that you know the length of the array column:

import pyspark.sql.functions as F

l = 4  # size of the array column
# Expand the array into one column per position, then average each column
# and reassemble the averages into a single array.
df1 = df.select("id", *[F.col("vector")[i] for i in range(l)])
out = df1.groupby("id").agg(F.array([F.mean(i)
                                     for i in df1.columns[1:]]).alias("vector"))

out.show(truncate=False)

+----------+----------------------------------------+
|id        |vector                                  |
+----------+----------------------------------------+
|1046091128|[8.3975, 3.325, 5.35, 6.325000000000001]|
+----------+----------------------------------------+