How can I count the number of consecutive positive or negative elements in each row?
Suppose I have the following data:
ski   2020  2021  2022  2023  2024  2025
book   1.2   5.6   8.4    -2    -5     6
jar    4.2    -5    -8     2     4     6
kook    -4  -5.2  -2.3  -5.6    -7     8
The desired output is a list for each row giving the length of each run of same-signed elements, with the sign of the run. For example, the first row has 3 positive elements, then 2 negative, then 1 positive, so the output is [3, -2, 1],
and for the other two rows the output is as follows:
jar [1,-2,3]
kook [-5,1]
CodePudding user response:
You can do it with a user-defined function built on Python's itertools.groupby, using lambda x: (1, -1)[x < 0] as the sign function.
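Before applying it in Spark, the core idea can be sketched in plain Python (a minimal sketch, using the book row from the question):

```python
from itertools import groupby

row = [1.2, 5.6, 8.4, -2.0, -5.0, 6.0]  # the "book" row

# Step 1: map each value to its sign; (1, -1)[x < 0] gives +1 for x >= 0, -1 otherwise
signs = [(1, -1)[x < 0] for x in row]   # [1, 1, 1, -1, -1, 1]

# Step 2: groupby collapses consecutive equal signs into runs;
# multiplying the run length by its sign yields the signed run lengths
runs = [s * len(list(g)) for s, g in groupby(signs)]
print(runs)  # [3, -2, 1]
```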
df.show()
# +----+------+------+------+------+----+----+
# |   0|     1|     2|     3|     4|   5|   6|
# +----+------+------+------+------+----+----+
# | ski|2020.0|2021.0|2022.0|2023.0|2024|2025|
# |book|   1.2|   5.6|   8.4|  -2.0|  -5|   6|
# | jar|   4.2|  -5.0|  -8.0|   2.0|   4|   6|
# |kook|  -4.0|  -5.2|  -2.3|  -5.6|  -7|   8|
# +----+------+------+------+------+----+----+
from itertools import groupby
from pyspark.sql.functions import udf, array
from pyspark.sql.types import ArrayType, IntegerType

def count_signs(l):
    # Map each value to its sign (+1 or -1), then collapse consecutive
    # equal signs with groupby and return the signed run lengths
    return [s * len(list(g)) for s, g in groupby(map(lambda x: (1, -1)[x < 0], l))]

count_signs_udf = udf(count_signs, ArrayType(IntegerType()))

# Apply the UDF to all numeric columns (every column except the first)
df.withColumn('signs', count_signs_udf(array(df.columns[1:]))).show()
# +----+------+------+------+------+----+----+----------+
# |   0|     1|     2|     3|     4|   5|   6|     signs|
# +----+------+------+------+------+----+----+----------+
# | ski|2020.0|2021.0|2022.0|2023.0|2024|2025|       [6]|
# |book|   1.2|   5.6|   8.4|  -2.0|  -5|   6|[3, -2, 1]|
# | jar|   4.2|  -5.0|  -8.0|   2.0|   4|   6|[1, -2, 3]|
# |kook|  -4.0|  -5.2|  -2.3|  -5.6|  -7|   8|  [-5, 1]|
# +----+------+------+----+------+----+----+----------+
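One edge case worth noting: the sign function (1, -1)[x < 0] maps zero to +1, so a zero joins the adjacent positive run rather than starting its own. The question's sample data contains no zeros, so this is an assumption about the intended behavior; a quick check in plain Python:

```python
from itertools import groupby

def count_signs(l):
    # (1, -1)[x < 0] evaluates to 1 when x >= 0, so zero counts as positive
    return [s * len(list(g)) for s, g in groupby(map(lambda x: (1, -1)[x < 0], l))]

print(count_signs([2.0, 0.0, -3.0]))  # [2, -1]: the zero extends the positive run
```

If zeros should instead break a run or be counted separately, replace the lambda with a three-way sign function before grouping.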