I have a sample dataset with salaries. I want to split the salaries into 3 buckets, find the lowest salary in each bucket, collect those lower bounds into an array, and attach that array to every row of the original dataset. I am trying to use window functions to do that, but the array seems to be built up progressively, row by row, instead of being the same on every row.
Here is the code that I have written:
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{collect_set, min, ntile}

val spark = sparkSession
import spark.implicits._
val simpleData = Seq(
  ("James", "Sales", 3000),
  ("Michael", "Sales", 3100),
  ("Robert", "Sales", 3200),
  ("Maria", "Finance", 3300),
  ("James", "Sales", 3400),
  ("Scott", "Finance", 3500),
  ("Jen", "Finance", 3600),
  ("Jeff", "Marketing", 3700),
  ("Kumar", "Marketing", 3800),
  ("Saif", "Sales", 3900)
)
val df = simpleData.toDF("employee_name", "department", "salary")
// Split the rows into 3 buckets ordered by salary
val windowSpec = Window.orderBy("salary")
val ntileFrame = df.withColumn("ntile", ntile(3).over(windowSpec))
// Lowest salary within each bucket
val lowWindowSpec = Window.partitionBy("ntile")
val ntileMinDf = ntileFrame.withColumn("lower_bound", min("salary").over(lowWindowSpec))
// Collect the lower bounds into an array on every row
var rangeDf = ntileMinDf.withColumn("range", collect_set("lower_bound").over(windowSpec))
rangeDf.show()
I am getting the dataset like this
+-------------+----------+------+-----+-----------+------------------+
|employee_name|department|salary|ntile|lower_bound|             range|
+-------------+----------+------+-----+-----------+------------------+
|        James|     Sales|  3000|    1|       3000|            [3000]|
|      Michael|     Sales|  3100|    1|       3000|            [3000]|
|       Robert|     Sales|  3200|    1|       3000|            [3000]|
|        Maria|   Finance|  3300|    1|       3000|            [3000]|
|        James|     Sales|  3400|    2|       3400|      [3000, 3400]|
|        Scott|   Finance|  3500|    2|       3400|      [3000, 3400]|
|          Jen|   Finance|  3600|    2|       3400|      [3000, 3400]|
|         Jeff| Marketing|  3700|    3|       3700|[3000, 3700, 3400]|
|        Kumar| Marketing|  3800|    3|       3700|[3000, 3700, 3400]|
|         Saif|     Sales|  3900|    3|       3700|[3000, 3700, 3400]|
+-------------+----------+------+-----+-----------+------------------+
I am expecting the dataset to look like this
+-------------+----------+------+-----+-----------+------------------+
|employee_name|department|salary|ntile|lower_bound|             range|
+-------------+----------+------+-----+-----------+------------------+
|        James|     Sales|  3000|    1|       3000|[3000, 3700, 3400]|
|      Michael|     Sales|  3100|    1|       3000|[3000, 3700, 3400]|
|       Robert|     Sales|  3200|    1|       3000|[3000, 3700, 3400]|
|        Maria|   Finance|  3300|    1|       3000|[3000, 3700, 3400]|
|        James|     Sales|  3400|    2|       3400|[3000, 3700, 3400]|
|        Scott|   Finance|  3500|    2|       3400|[3000, 3700, 3400]|
|          Jen|   Finance|  3600|    2|       3400|[3000, 3700, 3400]|
|         Jeff| Marketing|  3700|    3|       3700|[3000, 3700, 3400]|
|        Kumar| Marketing|  3800|    3|       3700|[3000, 3700, 3400]|
|         Saif|     Sales|  3900|    3|       3700|[3000, 3700, 3400]|
+-------------+----------+------+-----+-----------+------------------+
CodePudding user response:
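Because windowSpec has an orderBy and no explicit frame, Spark defaults to a frame that runs from the first row up to the current row, so collect_set only sees the lower bounds encountered so far. To make the window cover all rows rather than only the rows up to the current row, use the rowsBetween method with Window.unboundedPreceding and Window.unboundedFollowing as arguments. Your last line thus becomes: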
var rangeDf = ntileMinDf.withColumn(
  "range",
  collect_set("lower_bound")
    .over(Window.rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing))
)
and you get the following rangeDf dataframe:
+-------------+----------+------+-----+-----------+------------------+
|employee_name|department|salary|ntile|lower_bound|             range|
+-------------+----------+------+-----+-----------+------------------+
|        James|     Sales|  3000|    1|       3000|[3000, 3700, 3400]|
|      Michael|     Sales|  3100|    1|       3000|[3000, 3700, 3400]|
|       Robert|     Sales|  3200|    1|       3000|[3000, 3700, 3400]|
|        Maria|   Finance|  3300|    1|       3000|[3000, 3700, 3400]|
|        James|     Sales|  3400|    2|       3400|[3000, 3700, 3400]|
|        Scott|   Finance|  3500|    2|       3400|[3000, 3700, 3400]|
|          Jen|   Finance|  3600|    2|       3400|[3000, 3700, 3400]|
|         Jeff| Marketing|  3700|    3|       3700|[3000, 3700, 3400]|
|        Kumar| Marketing|  3800|    3|       3700|[3000, 3700, 3400]|
|         Saif|     Sales|  3900|    3|       3700|[3000, 3700, 3400]|
+-------------+----------+------+-----+-----------+------------------+
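If it helps to see the two frames side by side, here is a minimal sketch that puts the running frame (what an ordered window uses when no frame is given) next to the full frame on the same data. It is script-style code and assumes a local SparkSession that it creates itself; the names bySalary, runningFrame, fullFrame and running_range are made up for illustration. The sort_array call is my own addition, since collect_set does not promise any element order, and is not part of the fix above.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{collect_set, min, ntile, sort_array}

// Assumption: a throwaway local session; in the question `spark` already exists.
val spark = SparkSession.builder()
  .master("local[*]")
  .appName("window-frame-sketch")
  .getOrCreate()
import spark.implicits._

val df = Seq(
  ("James", "Sales", 3000), ("Michael", "Sales", 3100),
  ("Robert", "Sales", 3200), ("Maria", "Finance", 3300),
  ("James", "Sales", 3400), ("Scott", "Finance", 3500),
  ("Jen", "Finance", 3600), ("Jeff", "Marketing", 3700),
  ("Kumar", "Marketing", 3800), ("Saif", "Sales", 3900)
).toDF("employee_name", "department", "salary")

val bySalary = Window.orderBy("salary")

// Same first two steps as in the question: bucket, then per-bucket minimum.
val withBounds = df
  .withColumn("ntile", ntile(3).over(bySalary))
  .withColumn("lower_bound", min("salary").over(Window.partitionBy("ntile")))

// The frame an ordered window gets by default, written out explicitly:
// from the first row up to the current row.
val runningFrame = bySalary.rangeBetween(Window.unboundedPreceding, Window.currentRow)

// Same ordering, but every row sees the whole dataset.
val fullFrame = bySalary.rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)

val result = withBounds
  // Grows row by row, like the output in the question.
  .withColumn("running_range", sort_array(collect_set("lower_bound").over(runningFrame)))
  // Identical on every row; sorted so the order is deterministic.
  .withColumn("range", sort_array(collect_set("lower_bound").over(fullFrame)))

result.show(false)
```

With the sorting applied, the range column should hold [3000, 3400, 3700] on every row, while running_range reproduces the progressive behaviour. Note that a window with neither partitionBy nor orderBy moves all rows into a single partition, which is fine for a small sample like this but may be slow on large data.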