I am trying to compute the cosine similarity between two array-type columns in a PySpark DataFrame and add it as a third column, as shown below:
Col1 | Col2 | Dot Prod |
---|---|---|
[0.5, 0.6 ... 0.7] | [0.5, 0.3 .... 0.1] | dotProd(Col1, Col2) |
The current implementation I have is:
```python
import pyspark.sql.functions as func

def cosine_similarity(df, col1, col2):
    df_cosine = df.select(func.sum(df[col1] * df[col2]).alias('dot'),
                          func.sqrt(func.sum(df[col1]**2)).alias('norm1'),
                          func.sqrt(func.sum(df[col2]**2)).alias('norm2'))
    d = df_cosine.rdd.collect()[0].asDict()
    return d['dot'] / (d['norm1'] * d['norm2'])
```
But I think the code above only works for columns with numeric values. Is there any way to extend this function to get the same behavior for array columns?
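For reference, the value wanted in each row is the cosine of the two arrays in that row, computed element by element within the row. Note that the function above instead sums `df[col1] * df[col2]` over all rows, so it treats each scalar column as one long vector and returns a single number for the whole DataFrame. For arrays $\mathbf{a} = (a_1, \dots, a_n)$ and $\mathbf{b} = (b_1, \dots, b_n)$ in one row:

$$
\cos(\mathbf{a}, \mathbf{b}) = \frac{\sum_{i=1}^{n} a_i b_i}{\sqrt{\sum_{i=1}^{n} a_i^2}\,\sqrt{\sum_{i=1}^{n} b_i^2}}
$$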
Answer:
Yes, the code above works for numeric columns, not for arrays of numbers.
You can convert each array of numbers into a PySpark ML vector:
https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.linalg.Vectors.html
Then use `Vectors.dense` to build the vectors and the `dot` method to compute their dot product.
Example:

```python
from pyspark.ml.linalg import Vectors

x = Vectors.dense([1, 2, 3])
y = Vectors.dense([4, 5, 6])
x.dot(y)  # 1*4 + 2*5 + 3*6 = 32.0
```
Similarly, you can use the vector's `norm` method (for example `x.norm(2)` for the Euclidean norm) to compute the normalization terms.
Answer:
For arrays of doubles, you can use the higher-order SQL function `aggregate` (available since Spark 2.4) directly, without a UDF.
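```python
# Assumes an active SparkSession named `spark` (as in the pyspark shell)
from pyspark.sql import functions as f

df = spark.createDataFrame([[[0.1, 0.5, 2.0, 1.0], [3.0, 2.4, 0.2, 1.1]]], ['Col1', 'Col2'])
df.show()
```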
```
+--------------------+--------------------+
|                Col1|                Col2|
+--------------------+--------------------+
|[0.1, 0.5, 2.0, 1.0]|[3.0, 2.4, 0.2, 1.1]|
+--------------------+--------------------+
```
```python
# dot = sum of element-wise products; norm1/norm2 = Euclidean norms; cosine = dot / (norm1 * norm2)
df.withColumn('dot', f.expr('aggregate(arrays_zip(Col1, Col2), 0D, (acc, x) -> acc + (x.Col1 * x.Col2))')) \
  .withColumn('norm1', f.expr('sqrt(aggregate(Col1, 0D, (acc, x) -> acc + (x * x)))')) \
  .withColumn('norm2', f.expr('sqrt(aggregate(Col2, 0D, (acc, x) -> acc + (x * x)))')) \
  .withColumn('cosine', f.expr('dot / (norm1 * norm2)')) \
  .show(truncate=False)
```
```
+--------------------+--------------------+---+-----------------+-----------------+------------------+
|Col1                |Col2                |dot|norm1            |norm2            |cosine            |
+--------------------+--------------------+---+-----------------+-----------------+------------------+
|[0.1, 0.5, 2.0, 1.0]|[3.0, 2.4, 0.2, 1.1]|3.0|2.293468988235943|4.001249804748511|0.3269133956691457|
+--------------------+--------------------+---+-----------------+-----------------+------------------+
```