dropDuplicates with non-numeric condition


I have a dataframe that looks like this (I have a few more columns, but they aren't relevant):

+-----------+-----------+---------------+
|item_id    |location_id|decision       |
+-----------+-----------+---------------+
|     111111|  A        |        True   |
|     111111|  A        |        False  |
|     111111|  A        |        False  |
|     222222|  B        |        False  |
|     222222|  B        |        False  |
|     333333|  C        |        True   |
|     333333|  C        |        True   |
|     333333|  C        |        Unsure |
+-----------+-----------+---------------+

I would like to do dropDuplicates("item_id", "location_id") to remove rows that have the same item_id and location_id, but I want to keep a row that contains True or Unsure if one exists. If none of the duplicate rows contain True or Unsure, any row with False is fine. For the above example, I would like the resulting dataframe to look like this:

+-----------+-----------+---------------+
|item_id    |location_id|decision       |
+-----------+-----------+---------------+
|     111111|  A        |        True   |
|     222222|  B        |        False  |
|     333333|  C        |        Unsure |
+-----------+-----------+---------------+

For item_id 111111 and location_id A, I want the row with decision True since one such row exists. For item_id 222222 and location_id B, since none of the rows contain True, selecting either is fine. For item_id 333333 and location_id C, all rows contain the desired values of either True or Unsure, so selecting any one of the three is fine.

I am using Scala, so a solution in Scala would be appreciated.

Answer:

Here is the code for that:

Input Preparation:

//spark : My SparkSession
import spark.implicits._
  val df = Seq(
    (111111, "A", "True"),
    (111111, "A", "False"),
    (111111, "A", "False"),
    (222222, "B", "False"),
    (222222, "B", "False"),
    (333333, "C", "True"),
    (333333, "C", "True"),
    (333333, "C", "Unsure")
  ).toDF("item_id", "location_id", "decision")

  df.printSchema()

  /** root
    * |-- item_id: integer (nullable = false)
    * |-- location_id: string (nullable = true)
    * |-- decision: string (nullable = true)
    */

Code for achieving the desired output:

import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions.Window

//Step 1) Create a WindowSpec: MyWindow partitions the rows like groupBy("item_id", "location_id"),
//but a window also lets us order the rows inside each partition.
//Ordering by col("decision") descending puts Unsure first, then True, then False.
  val MyWindow = Window
    .partitionBy(col("item_id"), col("location_id"))
    .orderBy(col("decision").desc)

  df
//Step 2) Add a row_number to each record in its window, following the ordering in MyWindow
//(descending order of col("decision")).
    .withColumn("row_number", row_number().over(MyWindow))
//Step 3) Only the first row of each partition is needed, because the ordering already
//prefers Unsure, then True, then False, so keep only the rows with row_number 1.
    .filter(col("row_number") === 1)
    .drop("row_number")
    .show(false)

+-------+-----------+--------+
|item_id|location_id|decision|
+-------+-----------+--------+
|111111 |A          |True    |
|222222 |B          |False   |
|333333 |C          |Unsure  |
+-------+-----------+--------+
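The window approach depends on how the decision strings compare: in descending order, "Unsure" comes before "True", which comes before "False". The sketch below uses plain Scala collections rather than Spark to check that ordering and to mimic the "first row per key" rule. `Rec`, `WindowSketch` and `firstPerKeyDesc` are hypothetical names introduced only for this illustration.

```scala
// Hypothetical record type mirroring the three DataFrame columns.
final case class Rec(itemId: Int, locationId: String, decision: String)

object WindowSketch {
  // Same rows as the example DataFrame.
  val sample: Seq[Rec] = Seq(
    Rec(111111, "A", "True"),
    Rec(111111, "A", "False"),
    Rec(111111, "A", "False"),
    Rec(222222, "B", "False"),
    Rec(222222, "B", "False"),
    Rec(333333, "C", "True"),
    Rec(333333, "C", "True"),
    Rec(333333, "C", "Unsure")
  )

  // Emulates partitionBy(item_id, location_id), orderBy(decision.desc), row_number == 1:
  // the largest string in a group is the row that a descending sort would put first.
  def firstPerKeyDesc(rows: Seq[Rec]): Seq[Rec] =
    rows
      .groupBy(r => (r.itemId, r.locationId))
      .values
      .map(_.maxBy(_.decision))
      .toSeq
      .sortBy(_.itemId)
}
```

Because this relies on alphabetical order, a new decision value such as "Maybe" would sort between "False" and "True" whether or not that matches the intended preference, which is the problem the EDIT1 approach addresses.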

EDIT1 (as per comments):

Improved Code (without ordering col("decision") in WindowSpec):

To achieve this, you can write a custom UserDefinedAggregateFunction, which gives you more control over the range of values the decision attribute can take. For your requirement it can look like this:

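  import org.apache.spark.sql.Row
  import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
  import org.apache.spark.sql.types._

  object MyBestDecisionUDF extends UserDefinedAggregateFunction {

    // step 1) : set a priority score for each decision (this could be made configurable)
    val decisionOrderMap =
      Map("Unsure" -> 4, "True" -> 3, "False" -> 2, "Zinc" -> 1, "Copper" -> 0)

    /** all overridden functions come from the UserDefinedAggregateFunction abstract class */
    override def inputSchema: StructType = StructType(
      StructField("input_str", StringType, false) :: Nil
    )

    override def bufferSchema: StructType = StructType(
      StructField("buffer_str", StringType, false) :: Nil
    )

    override def dataType: DataType = StringType

    override def deterministic: Boolean = true

    // the empty string is not in decisionOrderMap, so it scores -1 and any known decision replaces it
    override def initialize(buffer: MutableAggregationBuffer): Unit = {
      buffer.update(0, "")
    }

    override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
      // main step : update the buffer so it always holds the best decision string seen so far
      if (
        decisionOrderMap.getOrElse(
          buffer.getAs[String](0),
          -1
        ) < decisionOrderMap.getOrElse(input(0).toString(), -1)
      ) {
        buffer.update(0, input(0))
      }
    }

    override def merge(
        buffer1: MutableAggregationBuffer,
        buffer2: Row
    ): Unit = {
      // combine two partial results with the same rule, so the UDAF also works with groupBy
      if (
        decisionOrderMap.getOrElse(
          buffer1.getAs[String](0),
          -1
        ) < decisionOrderMap.getOrElse(buffer2.getAs[String](0), -1)
      ) {
        buffer1.update(0, buffer2.getAs[String](0))
      }
    }

    override def evaluate(buffer: Row): Any = {
      buffer.getAs[String](0)
    }
  }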

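Stripped of the Spark plumbing, the aggregation is a fold over one group's decisions that keeps whichever value scores higher in the priority map. The following is a minimal sketch in plain Scala, not the UDAF itself; `BestDecisionSketch`, `better` and `best` are hypothetical names, and the priorities are copied from decisionOrderMap.

```scala
object BestDecisionSketch {
  // Same priorities as decisionOrderMap in the answer; values not in the map score -1.
  val decisionOrderMap: Map[String, Int] =
    Map("Unsure" -> 4, "True" -> 3, "False" -> 2, "Zinc" -> 1, "Copper" -> 0)

  private def score(decision: String): Int =
    decisionOrderMap.getOrElse(decision, -1)

  // Mirrors the comparison in update: keep the buffer unless the candidate scores strictly higher.
  def better(buffer: String, candidate: String): String =
    if (score(buffer) < score(candidate)) candidate else buffer

  // Mirrors initialize (an empty-string buffer) followed by update for every row of one group.
  def best(decisions: Seq[String]): String =
    decisions.foldLeft("")(better)
}
```

For example, `BestDecisionSketch.best(Seq("Copper", "Zinc"))` gives "Zinc". Two things follow from the comparison being strict: a group that contains only values missing from the map ends up with the empty string, and the same `better` function can combine two partial results, which is what a merge step needs.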
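  /** ###############################################################
    * Calling Custom UDAF
    * ###############################################################
    */
  val df2 = Seq(
    (111111, "A", "True"),
    (111111, "A", "False"),
    (111111, "A", "False"),
    (222222, "B", "False"),
    (222222, "B", "False"),
    (333333, "C", "True"),
    (333333, "C", "Unsure"),
    (444444, "D", "Copper"),
    (444444, "D", "Zinc")
  ).toDF("item_id", "location_id", "decision")
  df2.show(false)

+-------+-----------+--------+
|item_id|location_id|decision|
+-------+-----------+--------+
|111111 |A          |True    |
|111111 |A          |False   |
|111111 |A          |False   |
|222222 |B          |False   |
|222222 |B          |False   |
|333333 |C          |True    |
|333333 |C          |Unsure  |
|444444 |D          |Copper  |
|444444 |D          |Zinc    |
+-------+-----------+--------+

  // Custom UDF evaluated column: every row gets the best decision of its partition
  df2
    .withColumn(
      "decision",
      MyBestDecisionUDF(col("decision")).over(
        Window.partitionBy(col("item_id"), col("location_id"))
      )
    )
    .dropDuplicates("item_id", "location_id")
    .show(false)

+-------+-----------+--------+
|item_id|location_id|decision|
+-------+-----------+--------+
|111111 |A          |True    |
|222222 |B          |False   |
|333333 |C          |Unsure  |
|444444 |D          |Zinc    |
+-------+-----------+--------+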
