Home > Net >  Spark DataFrame Pivot with sort order in aggregation
Spark DataFrame Pivot with sort order in aggregation

Time:06-24

I'm wondering how to pivot and maintain order in aggregation.

DataFrame

+--------------+-------------+----------------+-----+
|CITY          |country      |STATE           |ORDER|
+--------------+-------------+----------------+-----+
|3. Corning    |United States|New York        |3    |
|1. Albany     |United States|New York        |1    |
|2. Batavia    |United States|New York        |2    |
|3. Campbell   |United States|California      |3    |
|2. Bakersfield|United States|California      |2    |
|1. Arvin      |United States|California      |1    |
|2. Tofino     |Canada       |British Columbia|2    |
|3. Vancouver  |Canada       |British Columbia|3    |
|1. Cranbrook  |Canada       |British Columbia|1    |
+--------------+-------------+----------------+-----+

Actual Result

+-------------+-------------------------------------+-------------------------------------+---------------------------------+
|COUNTRY      |British Columbia                     |California                           |New York                         |
+-------------+-------------------------------------+-------------------------------------+---------------------------------+
|United States|                                     |3. Campbell; 2. Bakersfield; 1. Arvin|3. Corning; 1. Albany; 2. Batavia|
|Canada       |2. Tofino; 3. Vancouver; 1. Cranbrook|                                     |                                 |
+-------------+-------------------------------------+-------------------------------------+---------------------------------+

Desired Output

+-------------+-------------------------------------+-------------------------------------+---------------------------------+
|COUNTRY      |British Columbia                     |California                           |New York                         |
+-------------+-------------------------------------+-------------------------------------+---------------------------------+
|United States|                                     |1. Arvin; 2. Bakersfield; 3. Campbell|1. Albany; 2. Batavia; 3. Corning|
|Canada       |1. Cranbrook; 2. Tofino; 3. Vancouver|                                     |                                 |
+-------------+-------------------------------------+-------------------------------------+---------------------------------+

Sample Code

import static org.apache.spark.sql.functions.col;
import static org.apache.spark.sql.functions.collect_list;
import static org.apache.spark.sql.functions.lit;

import java.io.Serializable;
import java.util.Arrays;
import java.util.List;

import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.functions;

/**
 * Small Spark application that reproduces a pivot whose aggregated cell values
 * come out in an arbitrary order.
 *
 * <p>The input rows carry the desired position as a numeric suffix of the state
 * (for example {@code "New York.3"}). The suffix is split off into an
 * {@code ORDER} column, prefixed to the city name, and the cities are then
 * pivoted by state and joined into a single {@code "; "}-separated string.
 */
public class TestPivotOrder implements Serializable {
    private static final long serialVersionUID = -1L;

    private static final String COUNTRY = "COUNTRY";
    private static final String STATE = "STATE";
    private static final String CITY = "CITY";
    private static final String ORDER = "ORDER";

    /**
     * Application entry point.
     *
     * @param args command-line arguments (unused)
     */
    public static void main(String[] args) {
        TestPivotOrder app = new TestPivotOrder();
        app.start();
    }

    /**
     * Builds the sample data, derives the ORDER/STATE/CITY columns, prints the
     * intermediate table, then pivots by state and prints the result.
     */
    private void start() {

        // display warning and error messages only
        Logger.getLogger("org.apache").setLevel(Level.WARN);

        SparkSession sparkSession = SparkSession.builder()
             .appName("Pivot with Order").master("local").getOrCreate();

        try {
            Dataset<Row> df = getData(sparkSession);

            // STATE arrives as "<name>.<n>":
            //  - ORDER gets the run of digits (group 0 = the whole match),
            //  - STATE loses the literal dot and the digits after it,
            //  - CITY is prefixed with "<n>. ".
            // The dot is escaped so it only matches a real '.', not any character.
            df = df
                .withColumn(ORDER, functions.regexp_extract(col(STATE), "(\\d+)", 0))
                .withColumn(STATE, functions.regexp_replace(col(STATE), "\\.(\\d+)", ""))
                .withColumn(CITY, functions.concat(col(ORDER), lit(". "), col(CITY)));

            /*
            +--------------+-------------+----------------+-----+
            |CITY          |country      |STATE           |ORDER|
            +--------------+-------------+----------------+-----+
            |3. Corning    |United States|New York        |3    |
            |1. Albany     |United States|New York        |1    |
            |2. Batavia    |United States|New York        |2    |
            |3. Campbell   |United States|California      |3    |
            |2. Bakersfield|United States|California      |2    |
            |1. Arvin      |United States|California      |1    |
            |2. Tofino     |Canada       |British Columbia|2    |
            |3. Vancouver  |Canada       |British Columbia|3    |
            |1. Cranbrook  |Canada       |British Columbia|1    |
            +--------------+-------------+----------------+-----+
             */
            df.show(10, false);

            // NOTE: collect_list gives no ordering guarantee, so the joined
            // cities appear in whatever order the rows were aggregated. Wrapping
            // the list in functions.sort_array(...) before concat_ws makes the
            // order deterministic.
            df = df
                .groupBy(df.col(COUNTRY))
                .pivot(df.col(STATE))
                .agg(functions.concat_ws("; ", collect_list(df.col(CITY))));

            df.show(10, false);
        } finally {
            // Always release the local Spark context, even if a step above fails.
            sparkSession.stop();
        }
    }

    /**
     * Creates the sample DataFrame from a fixed list of {@link Item} beans.
     *
     * @param spark the session used to create the DataFrame
     * @return a DataFrame built from the {@code Item} bean properties
     */
    private Dataset<Row> getData(SparkSession spark) {

        List<Item> list = Arrays.asList(
                new Item("United States", "New York.3", "Corning"),
                new Item("United States", "New York.1", "Albany"),
                new Item("United States", "New York.2", "Batavia"),

                new Item("United States", "California.3", "Campbell"),
                new Item("United States", "California.2", "Bakersfield"),
                new Item("United States", "California.1", "Arvin"),

                new Item("Canada", "British Columbia.2", "Tofino"),
                new Item("Canada", "British Columbia.3", "Vancouver"),
                new Item("Canada", "British Columbia.1", "Cranbrook")
            );

        return spark.createDataFrame(list, Item.class);
    }
}

Item Class

import java.io.Serializable;

/**
 * Read-only bean holding one (country, state, city) triple.
 *
 * <p>Values are set once in the constructor and exposed only through getters.
 */
public class Item implements Serializable {

    private static final long serialVersionUID = -1639L;

    // No setters exist, so the fields are final to make the immutability explicit.
    private final String country;
    private final String state;
    private final String city;

    /**
     * Creates an item; the arguments are stored as given (no validation).
     *
     * @param country the country name
     * @param state   the state name
     * @param city    the city name
     */
    public Item(String country, String state, String city) {
        this.country = country;
        this.state = state;
        this.city = city;
    }

    /** @return the country passed to the constructor */
    public String getCountry() {
        return country;
    }

    /** @return the state passed to the constructor */
    public String getState() {
        return state;
    }

    /** @return the city passed to the constructor */
    public String getCity() {
        return city;
    }
}

CodePudding user response:

// Sort each collected list with sort_array, then join the sorted array into a
// single "; "-separated string with concat_ws so every pivot cell is ordered.
// The sort is on the CITY strings themselves ("1. ...", "2. ..."), i.e.
// lexicographic: fine for single-digit prefixes, but "10." would sort before "2.".

        df = df
        .groupBy(df.col(COUNTRY))
        .pivot(df.col(STATE))
        .agg(functions.concat_ws("; ", functions.sort_array(collect_list(df.col(CITY)))));
  • Related