Dot product in PySpark?

Time:04-15

I have:

df1

+------------------+----------+
|               var|multiplier|
+------------------+----------+
|              var1|         1|
|              var2|         2|
|              var3|         3|
+------------------+----------+

df2

+-------+----------+-----+-----+----------+---------+
|   varA|      varB| varC| var1|      var2|     var3|
+-------+----------+-----+-----+----------+---------+
|   abcd|       at1|    5|    1|        45|       12|
|   xyzw|       vt1|    7|    1|        23|       17|
+-------+----------+-----+-----+----------+---------+

Result: df3

+-------+----------+-----+-----+----------+---------+---------------+
|   varA|      varB| varC| var1|      var2|     var3|     sumproduct|
+-------+----------+-----+-----+----------+---------+---------------+
|   abcd|       at1|    5|    1|        90|       36|            127|
|   xyzw|       vt1|    7|    1|        46|       51|             98|
+-------+----------+-----+-----+----------+---------+---------------+

In pandas, I am able to achieve this with:

df1 = df1.set_index(['var'])
df3 = df2.dot(df1)

Any help with a similar way to do the same in PySpark?
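For concreteness, the arithmetic I want, worked through in plain Python on the first sample row (values copied from the tables above):

```python
multipliers = {'var1': 1, 'var2': 2, 'var3': 3}   # df1: var -> multiplier
row = {'var1': 1, 'var2': 45, 'var3': 12}         # first row of df2 ("abcd")

# scale each var column by its multiplier, then sum the products
scaled = {k: v * multipliers[k] for k, v in row.items()}
sumproduct = sum(scaled.values())

print(scaled)       # {'var1': 1, 'var2': 90, 'var3': 36}
print(sumproduct)   # 127
```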

CodePudding user response:

from pyspark.sql import functions as F
from pyspark.sql.functions import array, expr

# Put the multipliers from df1 into a Python list
lst = df1.select("multiplier").rdd.flatMap(lambda x: x).collect()

df3 = (
    df2.withColumn('a1', array('var1', 'var2', 'var3'))  # pack df2's var columns into an array
    .withColumn('a2', array(*[F.lit(x) for x in lst]))  # literal array of df1's multipliers
    .withColumn('a1', expr("transform(a1, (x, i) -> a2[i] * x)"))  # element-wise product
    .select('varA', 'varB', 'varC', 'a1',
            *[F.col('a1')[i].alias(f'var{i + 1}') for i in range(3)])  # expand a1 back into var columns
    .select('*', expr("aggregate(a1, cast(0 as bigint), (acc, x) -> acc + x)").alias('sumproduct'))  # sum the products
    .drop('a1', 'a2')
)

df3.show()

+----+----+----+----+----+----+----------+
|varA|varB|varC|var1|var2|var3|sumproduct|
+----+----+----+----+----+----+----------+
|abcd| at1|   5|   1|  90|  36|       127|
|xyzw| vt1|   7|   1|  46|  51|        98|
+----+----+----+----+----+----+----------+
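As a mental model, Spark's `transform` and `aggregate` higher-order functions behave like a list comprehension and `functools.reduce`; a plain-Python sketch of what the two `expr` calls do to the first row (values from the example):

```python
from functools import reduce

row = [1, 45, 12]          # a1: var1..var3 from df2's first row
multipliers = [1, 2, 3]    # a2: df1's multiplier column

# transform(a1, (x, i) -> a2[i] * x): multiply element-wise by position
products = [multipliers[i] * x for i, x in enumerate(row)]   # [1, 90, 36]

# aggregate(a1, 0, (acc, x) -> acc + x): fold the products into a sum
sumproduct = reduce(lambda acc, x: acc + x, products, 0)     # 127
```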

Remember, if all you need is the dot product, a udf is also a possibility. We can use NumPy, which is very good at this kind of thing:

import numpy as np
from pyspark.sql import functions as F
from pyspark.sql.functions import array, udf
from pyspark.sql.types import IntegerType

lst = df1.select("multiplier").rdd.flatMap(lambda x: x).collect()
dot_array = udf(lambda x, y: int(np.dot(x, y)), IntegerType())
df2.withColumn("dotproduct",
               dot_array(array('var1', 'var2', 'var3'),
                         array(*[F.lit(x) for x in lst]))).show()

+----+----+----+----+----+----+----------+
|varA|varB|varC|var1|var2|var3|dotproduct|
+----+----+----+----+----+----+----------+
|abcd| at1|   5|   1|  45|  12|       127|
|xyzw| vt1|   7|   1|  23|  17|        98|
+----+----+----+----+----+----+----------+
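To sanity-check the udf's arithmetic off-cluster, here is the same `np.dot` call on plain arrays, with the values copied from the example rows:

```python
import numpy as np

multipliers = np.array([1, 2, 3])      # df1's multiplier column
row1 = np.array([1, 45, 12])           # df2 row "abcd"
row2 = np.array([1, 23, 17])           # df2 row "xyzw"

dot1 = int(np.dot(row1, multipliers))  # 1*1 + 45*2 + 12*3 = 127
dot2 = int(np.dot(row2, multipliers))  # 1*1 + 23*2 + 17*3 = 98
```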