I have a dataframe that looks like this (I have a few more columns, but they aren't relevant):
+-------+-----------+--------+
|item_id|location_id|decision|
+-------+-----------+--------+
| 111111|          A|    True|
| 111111|          A|   False|
| 111111|          A|   False|
| 222222|          B|   False|
| 222222|          B|   False|
| 333333|          C|    True|
| 333333|          C|    True|
| 333333|          C|  Unsure|
+-------+-----------+--------+
I would like to do dropDuplicates("item_id", "location_id") so I can remove rows that have the same item_id and location_id, but I want to keep a row that contains True or Unsure if one exists. If none of the duplicate rows contain True or Unsure, any row with False is fine. For the above example, I would like the resulting dataframe to look like this:
+-------+-----------+--------+
|item_id|location_id|decision|
+-------+-----------+--------+
| 111111|          A|    True|
| 222222|          B|   False|
| 333333|          C|  Unsure|
+-------+-----------+--------+
For item_id 111111 and location_id A, I want the row with decision True, since one such row exists. For item_id 222222 and location_id B, none of the rows contain True, so selecting either is fine. For item_id 333333 and location_id C, all rows contain the desired values of either True or Unsure, so selecting any one of the three is fine.
I am using Scala, so a solution in Scala would be appreciated.
CodePudding user response:
Here is the code for that:
Input Preparation:
//spark : My SparkSession
import spark.implicits._
val df = Seq(
  (111111, "A", "True"),
  (111111, "A", "False"),
  (111111, "A", "False"),
  (222222, "B", "False"),
  (222222, "B", "False"),
  (333333, "C", "True"),
  (333333, "C", "True"),
  (333333, "C", "Unsure")
).toDF("item_id", "location_id", "decision")
df.printSchema()
/** root
* |-- item_id: integer (nullable = false)
* |-- location_id: string (nullable = true)
* |-- decision: string (nullable = true)
*/
Code for achieving the desired output:
import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions.Window
// Step 1) Create a WindowSpec. MyWindow partitions the same way as
// groupBy("item_id", "location_id"), but additionally orders the rows within
// each partition by col("decision"), which is why a window function is used.
// Descending string order happens to match the order of preference:
// "Unsure" > "True" > "False" lexicographically.
val MyWindow = Window
  .partitionBy(col("item_id"), col("location_id"))
  .orderBy(desc("decision"))
df
  // Step 2) Assign a row_number to each record within its window, following
  // MyWindow's ordering, i.e. the descending order of col("decision").
  .withColumn("row_number", row_number().over(MyWindow))
  // Step 3) Only the first row of each partition is needed: thanks to the
  // descending order it holds Unsure if present, then True, then False,
  // matching the order of preference. So we keep only that row.
  .filter(col("row_number").equalTo(1))
  .drop(col("row_number"))
  .orderBy(col("item_id"))
  .show(false)
/**
OUTPUT:
+-------+-----------+--------+
|item_id|location_id|decision|
+-------+-----------+--------+
|111111 |A          |True    |
|222222 |B          |False   |
|333333 |C          |Unsure  |
+-------+-----------+--------+
*/
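Side note: this works only because descending string order happens to rank the three values in the preferred order. If that coincidence holds for your data, the same result can be had without a window function at all; here is a minimal sketch of that alternative. Caveat: it keeps only the grouped columns plus the aggregated decision, so it is unsuitable if you need the other columns of the winning row.
import org.apache.spark.sql.functions._

// groupBy + max: the string max picks "Unsure" over "True" over "False"
// purely by lexicographic order.
df.groupBy("item_id", "location_id")
  .agg(max("decision").as("decision"))
  .orderBy("item_id")
  .show(false)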
EDIT1 (as per comments):
Improved code (without ordering col("decision") in the WindowSpec):
To achieve this, you can write a custom UserDefinedAggregateFunction, which gives you more control over the range of values the decision attribute can take. For your requirement it can look like this:
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

object MyBestDecisionUDF extends UserDefinedAggregateFunction {
  // Step 1) Assign a priority score to each decision value; this map can be
  // made configurable elsewhere.
  val decisionOrderMap =
    Map("Unsure" -> 4, "True" -> 3, "False" -> 2, "Zinc" -> 1, "Copper" -> 0)

  /** All overridden members come from the UserDefinedAggregateFunction abstract class. */
  override def inputSchema: StructType = StructType(
    StructField("input_str", StringType, false) :: Nil
  )
  override def bufferSchema: StructType = StructType(
    StructField("buffer_str", StringType, false) :: Nil
  )
  override def dataType: DataType = StringType
  override def deterministic: Boolean = true
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer.update(0, "")
  }
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // Main step: keep the buffer holding the highest-priority decision seen so far.
    if (
      decisionOrderMap.getOrElse(buffer.get(0).toString, -1) <
        decisionOrderMap.getOrElse(input(0).toString, -1)
    ) {
      buffer.update(0, input(0))
    }
  }
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    // Combine partial aggregates the same way: keep the higher-priority
    // decision of the two buffers (leaving this empty would silently drop
    // results when Spark merges partial aggregates).
    if (
      decisionOrderMap.getOrElse(buffer1.get(0).toString, -1) <
        decisionOrderMap.getOrElse(buffer2.get(0).toString, -1)
    ) {
      buffer1.update(0, buffer2.get(0))
    }
  }
  override def evaluate(buffer: Row): Any = {
    buffer(0)
  }
}
/** ###############################################################
* Calling Custom UDAF
* ###############################################################
*/
/**
INPUT:
+-------+-----------+--------+
|item_id|location_id|decision|
+-------+-----------+--------+
|111111 |A          |True    |
|111111 |A          |False   |
|111111 |A          |False   |
|222222 |B          |False   |
|222222 |B          |False   |
|333333 |C          |True    |
|333333 |C          |Unsure  |
|444444 |D          |Copper  |
|444444 |D          |Zinc    |
+-------+-----------+--------+
*/
df
  // Custom UDAF evaluated over a window per (item_id, location_id)
  .withColumn(
    "my_best_decision",
    MyBestDecisionUDF(col("decision")).over(
      Window
        .partitionBy(col("item_id"), col("location_id"))
    )
  )
  .drop(col("decision"))
  .distinct()
  .orderBy(col("item_id"))
  .show(false)
/**
* OUTPUT:
+-------+-----------+----------------+
|item_id|location_id|my_best_decision|
+-------+-----------+----------------+
|111111 |A          |True            |
|222222 |B          |False           |
|333333 |C          |Unsure          |
|444444 |D          |Zinc            |
+-------+-----------+----------------+
*/
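Closing note: UserDefinedAggregateFunction is deprecated as of Spark 3.0 in favor of org.apache.spark.sql.expressions.Aggregator. Below is a minimal sketch of the same priority logic in that style, assuming Spark 3.x (BestDecisionAgg and myBestDecision are illustrative names, not part of any API):
import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions.udaf

// Typed Aggregator carrying the same priority map; reduce and merge both
// keep the higher-priority decision.
object BestDecisionAgg extends Aggregator[String, String, String] {
  val decisionOrderMap =
    Map("Unsure" -> 4, "True" -> 3, "False" -> 2, "Zinc" -> 1, "Copper" -> 0)
  def zero: String = ""
  def reduce(buf: String, in: String): String =
    if (decisionOrderMap.getOrElse(buf, -1) < decisionOrderMap.getOrElse(in, -1)) in else buf
  def merge(b1: String, b2: String): String = reduce(b1, b2)
  def finish(reduction: String): String = reduction
  def bufferEncoder: Encoder[String] = Encoders.STRING
  def outputEncoder: Encoder[String] = Encoders.STRING
}

// udaf(...) wraps the Aggregator for use in untyped DataFrame code, e.g.
// myBestDecision(col("decision")).over(Window.partitionBy(...)) as above.
val myBestDecision = udaf(BestDecisionAgg)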