How can I compare pairs of columns in a PySpark dataframe and get 1/0 flag?

Time:11-02

I have a situation where I need to compare multiple pairs of columns (the number of pairs will vary and can come from a list, as shown in the code snippet below) and get a 1/0 flag for match/mismatch respectively.

col_pairs = ['Marks', 'Qualification']
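Both answers below assume that each entry in col_pairs is a prefix and that the two columns to compare are named <prefix>_1 and <prefix>_2. The following minimal plain-Python sketch shows that assumption. The helper pair_columns is hypothetical: it turns the prefix list into column-name pairs and fails fast if either column of a pair is missing. In PySpark the columns argument would be df.columns.

```python
def pair_columns(prefixes, columns, suffixes=("_1", "_2")):
    """Hypothetical helper: map each prefix to its (left, right) column names.

    Assumes the <prefix>_1 / <prefix>_2 naming used in the answers and
    raises KeyError if either column of a pair is absent from `columns`.
    """
    available = set(columns)
    pairs = []
    for prefix in prefixes:
        left, right = (prefix + s for s in suffixes)
        missing = [name for name in (left, right) if name not in available]
        if missing:
            raise KeyError(f"missing columns for {prefix!r}: {missing}")
        pairs.append((left, right))
    return pairs


columns = ["Name", "Marks_1", "Marks_2", "Qualification_1", "Qualification_2"]
print(pair_columns(["Marks", "Qualification"], columns))
# [('Marks_1', 'Marks_2'), ('Qualification_1', 'Qualification_2')]
```

Each (left, right) tuple can then become one comparison expression, for example F.col(left) == F.col(right). That way all the flags are still computed in a single select rather than in a loop.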

The first image is the source DF and the second image is the expected DF.

[Image: Source DF]

[Image: Expected results]

This can be done using a when condition in a loop, but that is inefficient: it processes one pair at a time and takes a long time on about 1 billion records.

Answer:

df.colRegex may serve you well: for each prefix, it selects every column whose name matches the regex. If all the values in those columns are equal, you get 1; otherwise you get 0. The script is efficient because all the flags are computed in a single select.
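colRegex takes a backtick-quoted regular expression and returns the columns whose names match it. The plain-Python sketch below models only which names the pattern ^{c}.* picks up. It uses Python's re as a stand-in for Spark's Java regex engine, and the two agree for a simple anchored pattern like this one. One caveat: the prefix pattern also catches any other column that starts with the same text. For example, if the script is run again on its own output, the Marks_result column would be matched as well. A tighter pattern such as ^{c}_\d+$ avoids that. The helper matching_columns is hypothetical.

```python
import re


def matching_columns(columns, prefix, pattern="^{}.*"):
    """Hypothetical model of which column names a prefix regex selects.

    Uses Python's `re`, not Spark. re.escape guards prefixes that contain
    regex metacharacters (the answer interpolates the prefix as-is).
    """
    regex = re.compile(pattern.format(re.escape(prefix)))
    return [name for name in columns if regex.fullmatch(name)]


columns = ["Name", "Marks_1", "Marks_2", "Qualification_1",
           "Qualification_2", "Marks_result"]

# The loose pattern from the answer also picks up the derived result column.
print(matching_columns(columns, "Marks"))
# ['Marks_1', 'Marks_2', 'Marks_result']

# A tighter pattern keeps only the numbered source columns.
print(matching_columns(columns, "Marks", pattern=r"^{}_\d+$"))
# ['Marks_1', 'Marks_2']
```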

Inputs:

from pyspark.sql import functions as F
df = spark.createDataFrame(
    [('p', 1, 2, 'g', 'm'),
     ('a', 3, 3, 'g', 'g'),
     ('b', 4, 5, 'g', 'g'),
     ('r', 8, 8, 'm', 'm'),
     ('d', 2, 1, 'u', 'g')],
    ['Name', 'Marks_1', 'Marks_2', 'Qualification_1', 'Qualification_2'])

col_pairs = ['Marks', 'Qualification']

Script:

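def equals(*cols):
    # 1 if every value in the given columns is identical, else 0:
    # collect the values into an array, drop duplicates, count what is left.
    return (F.size(F.array_distinct(F.array(*cols))) == 1).cast('int')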

df = df.select(
    '*',
    *[equals(df.colRegex(f"`^{c}.*`")).alias(f'{c}_result') for c in col_pairs]
)

df.show()
# +----+-------+-------+---------------+---------------+------------+--------------------+
# |Name|Marks_1|Marks_2|Qualification_1|Qualification_2|Marks_result|Qualification_result|
# +----+-------+-------+---------------+---------------+------------+--------------------+
# |   p|      1|      2|              g|              m|           0|                   0|
# |   a|      3|      3|              g|              g|           1|                   1|
# |   b|      4|      5|              g|              g|           0|                   1|
# |   r|      8|      8|              m|              m|           1|                   1|
# |   d|      2|      1|              u|              g|           0|                   0|
# +----+-------+-------+---------------+---------------+------------+--------------------+
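The rule inside equals builds an array from the matched columns, removes duplicates, and returns 1 only if exactly one distinct value remains. The plain-Python model below is a sketch, not Spark code, with None standing in for SQL NULL. It shows two properties of the rule. First, it works for any number of columns per prefix, not only pairs. Second, Spark's array_distinct keeps a single null, so two nulls produce flag 1, while a value compared with a null produces 0. The function name equals_flag is hypothetical.

```python
def equals_flag(*values):
    """Model of size(array_distinct(array(*cols))) == 1, cast to int.

    None stands in for SQL NULL; like array_distinct, it is kept only once.
    """
    distinct = []
    for value in values:
        if value not in distinct:  # dedupe while keeping first-seen order
            distinct.append(value)
    return int(len(distinct) == 1)


print(equals_flag(3, 3))        # 1
print(equals_flag(4, 5))        # 0
print(equals_flag(8, 8, 8))     # 1  (three columns with the same prefix)
print(equals_flag(None, None))  # 1  (two nulls collapse to one)
print(equals_flag(2, None))     # 0
```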

Proof of efficiency:

df.explain()
# == Physical Plan ==
# *(1) Project [Name#636, Marks_1#637L, Marks_2#638L, Qualification_1#639, Qualification_2#640, cast((size(array_distinct(array(Marks_1#637L, Marks_2#638L)), true) = 1) as int) AS Marks_result#646, cast((size(array_distinct(array(Qualification_1#639, Qualification_2#640)), true) = 1) as int) AS Qualification_result#647]
# +- Scan ExistingRDD[Name#636,Marks_1#637L,Marks_2#638L,Qualification_1#639,Qualification_2#640]

Answer:

I'm not sure what loop you are running, but here's an implementation that uses a list comprehension inside a single select.

from pyspark.sql import functions as func

data_ls = [
    (10, 11, 'foo', 'foo'),
    (12, 12, 'bar', 'bar'),
    (10, 12, 'foo', 'bar')
]

data_sdf = spark.sparkContext.parallelize(data_ls). \
    toDF(['marks_1', 'marks_2', 'qualification_1', 'qualification_2'])

col_pairs = ['marks','qualification']

data_sdf. \
    select('*',
           *[(func.col(c + '_1') == func.col(c + '_2')).cast('int').alias(c + '_check') for c in col_pairs]
           ). \
    show()

# +-------+-------+---------------+---------------+-----------+-------------------+
# |marks_1|marks_2|qualification_1|qualification_2|marks_check|qualification_check|
# +-------+-------+---------------+---------------+-----------+-------------------+
# |     10|     11|            foo|            foo|          0|                  1|
# |     12|     12|            bar|            bar|          1|                  1|
# |     10|     12|            foo|            bar|          0|                  0|
# +-------+-------+---------------+---------------+-----------+-------------------+

The list comprehension yields the following list of Column expressions:

[(func.col(c + '_1') == func.col(c + '_2')).cast('int').alias(c + '_check') for c in col_pairs]
# [Column<'CAST((marks_1 = marks_2) AS INT) AS `marks_check`'>,
#  Column<'CAST((qualification_1 = qualification_2) AS INT) AS `qualification_check`'>]
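Because each element is a Spark SQL expression, the same checks could also be written as SQL strings and passed to DataFrame.selectExpr. The sketch below only builds those strings, and the helper name check_exprs is hypothetical. One behavioural difference from the colRegex answer is worth knowing. The = operator returns NULL when either side is NULL, so the flag is NULL rather than 0 or 1. Spark's null-safe <=> operator (Column.eqNullSafe in the DataFrame API) always returns true or false instead.

```python
def check_exprs(prefixes, null_safe=False):
    """Hypothetical helper: build one Spark SQL flag expression per prefix.

    Assumes the <prefix>_1 / <prefix>_2 naming from the answer above.
    Backticks quote the identifiers in case a name has special characters.
    """
    op = "<=>" if null_safe else "="
    return [
        f"CAST((`{p}_1` {op} `{p}_2`) AS INT) AS `{p}_check`"
        for p in prefixes
    ]


col_pairs = ["marks", "qualification"]
for expr in check_exprs(col_pairs):
    print(expr)
# CAST((`marks_1` = `marks_2`) AS INT) AS `marks_check`
# CAST((`qualification_1` = `qualification_2`) AS INT) AS `qualification_check`
```

With a DataFrame such as data_sdf, something like data_sdf.selectExpr('*', *check_exprs(col_pairs)) should add the same two flag columns in one projection.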