I'm wondering how to pivot and maintain order in aggregation.
DataFrame
-------------- ------------- ---------------- -----
|CITY |country |STATE |ORDER|
-------------- ------------- ---------------- -----
|3. Corning |United States|New York |3 |
|1. Albany |United States|New York |1 |
|2. Batavia |United States|New York |2 |
|3. Campbell |United States|California |3 |
|2. Bakersfield|United States|California |2 |
|1. Arvin |United States|California |1 |
|2. Tofino |Canada |British Columbia|2 |
|3. Vancouver |Canada |British Columbia|3 |
|1. Cranbrook |Canada |British Columbia|1 |
-------------- ------------- ---------------- -----
Actual Result
------------- ------------------------------------- ------------------------------------- ---------------------------------
|COUNTRY |British Columbia |California |New York |
------------- ------------------------------------- ------------------------------------- ---------------------------------
|United States| |3. Campbell; 2. Bakersfield; 1. Arvin|3. Corning; 1. Albany; 2. Batavia|
|Canada |2. Tofino; 3. Vancouver; 1. Cranbrook| | |
------------- ------------------------------------- ------------------------------------- ---------------------------------
Desire Output
------------- ------------------------------------- ------------------------------------- ---------------------------------
|COUNTRY |British Columbia |California |New York |
------------- ------------------------------------- ------------------------------------- ---------------------------------
|United States| |1. Arvin; 2. Bakersfield; 3. Campbell|1. Albany; 2. Batavia; 3. Corning|
|Canada |1. Cranbrook; 2. Tofino; 3. Vancouver| | |
------------- ------------------------------------- ------------------------------------- ---------------------------------
Sample Code
import static org.apache.spark.sql.functions.col;
import static org.apache.spark.sql.functions.collect_list;
import static org.apache.spark.sql.functions.lit;
import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.functions;
public class TestPivotOrder implements Serializable {
private static final long serialVersionUID = -1L;
private static String COUNTRY = "COUNTRY";
private static String STATE = "STATE";
private static String CITY = "CITY";
private static String ORDER = "ORDER";
public static void main(String[] args) {
TestPivotOrder app = new TestPivotOrder();
app.start();
}
private void start() {
// display warning and error messages only
Logger.getLogger("org.apache").setLevel(Level.WARN);
SparkSession sparkSession = SparkSession.builder()
.appName("Pivot with Order").master("local").getOrCreate();
Dataset<Row> df = getData(sparkSession);
df = df
.withColumn(ORDER, functions.regexp_extract(col(STATE), "(\\d )", 0))
.withColumn(STATE, functions.regexp_replace(col(STATE), ".(\\d )", ""))
.withColumn(CITY, functions.concat(col(ORDER), lit(". "), col(CITY)));
/*
-------------- ------------- ---------------- -----
|CITY |country |STATE |ORDER|
-------------- ------------- ---------------- -----
|3. Corning |United States|New York |3 |
|1. Albany |United States|New York |1 |
|2. Batavia |United States|New York |2 |
|3. Campbell |United States|California |3 |
|2. Bakersfield|United States|California |2 |
|1. Arvin |United States|California |1 |
|2. Tofino |Canada |British Columbia|2 |
|3. Vancouver |Canada |British Columbia|3 |
|1. Cranbrook |Canada |British Columbia|1 |
-------------- ------------- ---------------- -----
*/
df.show(10, false);
df = df
.groupBy(df.col(COUNTRY))
.pivot(df.col(STATE))
.agg(functions.concat_ws("; ", collect_list(df.col(CITY))));
df.show(10, false);
sparkSession.stop();
}
private Dataset<Row> getData(SparkSession spark) {
List<Item> list = Arrays.asList(
new Item("United States","New York.3","Corning"),
new Item("United States","New York.1","Albany"),
new Item("United States","New York.2","Batavia"),
new Item("United States","California.3","Campbell"),
new Item("United States","California.2","Bakersfield"),
new Item("United States","California.1","Arvin"),
new Item("Canada","British Columbia.2","Tofino"),
new Item("Canada","British Columbia.3","Vancouver"),
new Item("Canada","British Columbia.1","Cranbrook")
);
return spark.createDataFrame(list, Item.class);
}
}
Item Class
import java.io.Serializable;
public class Item implements Serializable {
private static final long serialVersionUID = -1639L;
private String country;
private String state;
private String city;
public Item(String country, String state, String city) {
this.country = country;
this.state = state;
this.city = city;
}
public String getCountry() {
return country;
}
public String getState() {
return state;
}
public String getCity() {
return city;
}
}
CodePudding user response:
// sort_array then mkString with array
df = df
.groupBy(df.col(COUNTRY))
.pivot(df.col(STATE))
.agg(sort_array(collect_list(df.col(CITY))));