I have a df with 2 columns:
- id
- vector
This is a sample of how it looks:
+--------------------+----------+
|              vector|        id|
+--------------------+----------+
|[8.32,3.22,5.34,6.5]|1046091128|
|[8.52,3.34,5.31,6.3]|1046091128|
|[8.44,3.62,5.54,6.4]|1046091128|
|[8.31,3.12,5.21,6.1]|1046091128|
+--------------------+----------+
I want to groupBy id and take the element-wise mean of the vectors. So, for example, the first value in the aggregated list will be (8.32 + 8.52 + 8.44 + 8.31)/4 = 8.3975, and so on.
Any help is appreciated.
CodePudding user response:
You can use the posexplode function and then aggregate the exploded values by average. Something like below:
from pyspark.sql.functions import avg, posexplode
from pyspark.sql.types import StructType, StructField, ArrayType, FloatType, IntegerType
data = [([8.32, 3.22, 5.34, 6.5], 1046091128), ([8.52, 3.34, 5.31, 6.3], 1046091128),
        ([8.44, 3.62, 5.54, 6.4], 1046091128), ([8.31, 3.12, 5.21, 6.1], 1046091128)]
schema = StructType([StructField("vector", ArrayType(FloatType())), StructField("id", IntegerType())])
df = spark.createDataFrame(data=data, schema=schema)
# posexplode emits one row per array element with its position ("pos") and value ("col");
# pivoting on "pos" yields one column per position, and each is averaged per id.
df.select("id", posexplode("vector")).groupBy("id").pivot("pos").agg(avg("col")).show()
The output would look like this:
+----------+-----------------+------------------+-----------------+-----------------+
|        id|                0|                 1|                2|                3|
+----------+-----------------+------------------+-----------------+-----------------+
|1046091128|8.397500038146973|3.3249999284744263|5.350000023841858|6.325000047683716|
+----------+-----------------+------------------+-----------------+-----------------+
You can rename the columns later if required.
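For example, if you keep the aggregated result in a variable instead of calling .show() directly, one way to rename is with toDF (a minimal sketch; the v0..v3 names are just illustrative, not required by anything above):
agg = df.select("id", posexplode("vector")).groupBy("id").pivot("pos").agg(avg("col"))
# The pivoted columns come out named "0".."3"; rename them all in one pass.
renamed = agg.toDF("id", "v0", "v1", "v2", "v3")
renamed.show()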
CodePudding user response:
This assumes that you know the length of the array column (a sketch for inferring it follows the output below):
import pyspark.sql.functions as F

l = 4  # size of the array column
# Expand the array into one column per element, average each column per id,
# and reassemble the means into a single array column.
df1 = df.select("id", *[F.col("vector")[i] for i in range(l)])
out = df1.groupby("id").agg(F.array([F.mean(i) for i in df1.columns[1:]]).alias("vector"))
out.show(truncate=False)
+----------+----------------------------------------+
|id        |vector                                  |
+----------+----------------------------------------+
|1046091128|[8.3975, 3.325, 5.35, 6.325000000000001]|
+----------+----------------------------------------+
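If the array length isn't known up front, it can be inferred from the data first (a sketch, assuming every row's vector has the same length):
import pyspark.sql.functions as F
# Read the array length off the first row; assumes all vectors are the same
# length, which the element-wise mean requires anyway.
l = df.select(F.size("vector").alias("n")).first()["n"]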