Getting unique values in a dataframe by using `df.select(column).distinct().collect()`

Time:08-02

As per my limited understanding of how Spark works, when the .collect() action is called, the data in column `column` is partitioned and split among executors, the .distinct() transformation is applied to each of those partitions, and the deduplicated results are sent to the driver. But isn't there a chance of records being duplicated at the driver, since the deduplication took place independently on each executor? Do we need to apply .distinct() again on the collected result to get rid of duplicates?
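To make the concern concrete, here is a plain-Python sketch (no Spark involved) of what would happen if deduplication stopped at the partition level and the driver simply concatenated the per-partition results. The three-way split of the values is a hypothetical assumption for illustration.

```python
# Hypothetical split of the column values across three executors
# (an illustrative assumption, not the split Spark would choose).
partitions = [[1, 2, 2], [3, 3, 3], [1, 4, 5]]

# Deduplicate each partition independently (a local, per-executor distinct).
local_distinct = [sorted(set(p)) for p in partitions]

# Naively concatenate the per-partition results, as a driver would on collect().
naive_result = [v for part in local_distinct for v in part]

print(local_distinct)  # [[1, 2], [3], [1, 4, 5]]
print(naive_result)    # [1, 2, 3, 1, 4, 5]  <- the value 1 appears twice
```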

Answer:

Your idea is correct, but you are missing one step: the reduce phase. Spark executes aggregations much like MapReduce does. In MapReduce, a distinct aggregation runs in four steps:

  1. The data is read by each mapper (split among executors, as you said).
  2. A combiner performs a distinct on each mapper's output (still in the mapper process).
  3. (The missing step) After all mappers are done, a reducer process is started (still in the cluster) that aggregates the distinct lists from all mappers and performs distinct again.
  4. The result is sent to the client.
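A minimal plain-Python sketch of these four steps, assuming a single reducer as described above and a hypothetical three-way split of the input. The variable names are illustrative only; they are not Hadoop or Spark APIs.

```python
from itertools import chain

# Step 1: hypothetical input splits, one per mapper (assumed for illustration).
splits = [[1, 2, 2], [3, 3, 3], [1, 4, 5]]

# Step 2: each mapper's combiner removes duplicates within its own split.
combined = [set(split) for split in splits]

# Step 3: a single reducer receives every combiner's output and deduplicates again.
reducer_input = list(chain.from_iterable(combined))
reduced = set(reducer_input)

# Step 4: the reducer's output is returned to the client.
result = sorted(reduced)
print(result)  # [1, 2, 3, 4, 5]
```

Note that `reducer_input` still contains the value 1 twice; only the reducer's second distinct removes it.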

Spark does the same, but unlike MapReduce, Spark uses the executors for every part of the execution (mappers, combiners, and reducers).

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([[i] for i in [1, 2, 2, 3, 3, 3, 1, 4, 5]], ['n'])
df.show()
+---+
|  n|
+---+
|  1|
|  2|
|  2|
|  3|
|  3|
|  3|
|  1|
|  4|
|  5|
+---+
df_distinct = df.distinct()

df_distinct.explain()

== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- HashAggregate(keys=[n#0L], functions=[])
   +- Exchange hashpartitioning(n#0L, 200), ENSURE_REQUIREMENTS, [id=#109]
      +- HashAggregate(keys=[n#0L], functions=[])
         +- InMemoryTableScan [n#0L]
               +- InMemoryRelation [n#0L], StorageLevel(disk, memory, deserialized, 1 replicas)
                     +- *(1) Scan ExistingRDD[n#0L]


print(df_distinct.collect())
[Row(n=5), Row(n=1), Row(n=3), Row(n=2), Row(n=4)]

To understand what Spark does, let's look at the physical plan of df.distinct() shown above.

Focus on these three lines of the plan (read them from the bottom up):

(3) +- HashAggregate(keys=[n#0L], functions=[])
(2)    +- Exchange hashpartitioning(n#0L, 200), ENSURE_REQUIREMENTS, [id=#109]
(1)       +- HashAggregate(keys=[n#0L], functions=[])

A DataFrame/RDD is already partitioned, and its partitions reside in the executors.

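(1) HashAggregate - This step performs the first distinct, locally within each existing partition.

(2) Exchange hashpartitioning - This stage shuffles the data into 200 partitions by hashing n, so every row with a given value of n ends up in the same partition. The data is not sent to a single executor; the point is that equal values are co-located.

(3) HashAggregate - This step performs the second distinct within each shuffled partition. Because the exchange already grouped all copies of a value together, no value can appear in more than one partition after this step.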

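Then collect() gathers the rows from all partitions to the driver. Since the partitions hold disjoint sets of values, the collected list is already duplicate-free, so there is no need to call .distinct() again on the result.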
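The three plan stages plus collect() can be mimicked in plain Python to see why the driver receives no duplicates. This is a sketch, not Spark's implementation: Python's built-in `hash` stands in for Spark's Murmur3-based hash partitioner, `NUM_SHUFFLE_PARTITIONS` is a hypothetical stand-in for the 200 in `hashpartitioning(n#0L, 200)`, and the initial split is assumed.

```python
from collections import defaultdict

# Hypothetical initial partitions (Spark chooses the real split itself).
partitions = [[1, 2, 2], [3, 3, 3], [1, 4, 5]]
NUM_SHUFFLE_PARTITIONS = 4  # hypothetical stand-in for the 200 shuffle partitions

# (1) HashAggregate: partial distinct inside each existing partition.
partial = [set(p) for p in partitions]

# (2) Exchange hashpartitioning: route each value to partition hash(value) % N,
#     so all copies of the same value land in the same shuffle partition.
shuffled = defaultdict(list)
for part in partial:
    for value in part:
        shuffled[hash(value) % NUM_SHUFFLE_PARTITIONS].append(value)

# (3) HashAggregate: final distinct inside each shuffle partition.
final = {pid: set(values) for pid, values in shuffled.items()}

# collect(): concatenate every partition's rows on the driver, with no extra distinct.
collected = [v for pid in sorted(final) for v in final[pid]]
print(sorted(collected))  # [1, 2, 3, 4, 5]
```

Changing `NUM_SHUFFLE_PARTITIONS` changes how values are spread across shuffle partitions and the order of `collected`, but not the set of values the driver receives.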
