Why is nested When().Then() slower than Left Join in Rust Polars?

Time:12-04

In Rust Polars (this might apply to Python pandas as well), there are two ways to assign values to a new column using complex logic that involves the values of other columns. The default way is a nested WhenThen expression. The other way is a LeftJoin. Naturally, I would expect When Then to be much faster than Join, but that is not the case: in this example, When Then is about 6 times slower than Join. Is that actually expected? Am I using When Then wrong?

In this example, the goal is to assign a weights (multipliers) column based on three other columns: country, city and bucket.

use std::collections::HashMap;

use polars::prelude::*;
use rand::{distributions::Uniform, Rng}; // 0.6.5

pub fn bench() {
    // PREPARATION
    // This MAP is to be used for Left Join
    let mut weights = df![
        "country"=>vec!["UK"; 5],
        "city"=>vec!["London"; 5],
        "bucket" => ["1","2","3","4","5"],
        "weights" => [0.1, 0.2, 0.3, 0.4, 0.5]
    ].unwrap().lazy();
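    // Wrap each weight in a one-element list so the joined column holds lists, like the When/Then output
    weights = weights.with_column(concat_lst([col("weights")]).alias("weights"));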

    // This MAP to be used in When.Then
    let weight_map = bucket_weight_map(&[0.1, 0.2, 0.3, 0.4, 0.5], 1);


    // Generate the DataSet itself
    let mut rng = rand::thread_rng();
    // Note: Uniform::new is half-open, so this samples buckets 1..=4 only (bucket "5" never occurs)
    let range = Uniform::new(1, 5);
    let b: Vec<String> = (0..10_000_000).map(|_| rng.sample(&range).to_string()).collect();
    let rc = vec!["UK"; 10_000_000];
    let rf = vec!["London"; 10_000_000];
    let val = vec![1; 10_000_000];
    let frame = df!(
        "country" => rc,
        "city" => rf,
        "bucket" => b,
        "val" => val,
    ).unwrap().lazy();

    // Test with Left Join
    use std::time::Instant;
    let now = Instant::now();
    let r = frame.clone()
        .join(weights, [col("country"), col("city"), col("bucket")], [col("country"), col("city"), col("bucket")], JoinType::Left)
        .collect().unwrap();
    let elapsed = now.elapsed();
    println!("Left Join took: {:.2?}", elapsed);

    // Test with nested When Then
    let now = Instant::now();
    let r1 = frame.clone().with_column(
        when(col("country").eq(lit("UK")))
            .then(
                when(col("city").eq(lit("London")))
                .then(rf_rw_map(col("bucket"),weight_map,NULL.lit()))
                .otherwise(NULL.lit())
            )
            .otherwise(NULL.lit())
        )
        .collect().unwrap();
    let elapsed = now.elapsed();
    println!("Chained When Then: {:.2?}", elapsed);

    // Check results are identical
    dbg!(r.tail(Some(10)));
    dbg!(r1.tail(Some(10)));
}

/// Builds a chained when().then()...otherwise() expression with one branch per map entry
fn rf_rw_map(col: Expr, map: HashMap<String, Expr>, other: Expr) -> Expr {
    // buf is a placeholder
    let mut it = map.into_iter();
    let (k, v) = it.next().unwrap(); //The map will have at least one value

    let mut buf = when(lit::<bool>(false)) // buffer WhenThen
        .then(lit::<f64>(0.).list()) // buffer WhenThen, needed to "chain on to"
        .when(col.clone().eq(lit(k)))
        .then(v);

    for (k, v) in it {
        buf = buf
            .when(col.clone().eq(lit(k)))
            .then(v);
    }
    buf.otherwise(other)
}

fn bucket_weight_map(arr: &[f64], ntenors: u8) -> HashMap<String, Expr> {
    let mut bucket_weights: HashMap<String, Expr> = HashMap::default();
    for (i, n) in arr.iter().enumerate() {
        let j = i + 1;
        bucket_weights.insert(
            format!("{j}"),
            Series::from_vec("weight", vec![*n; ntenors as usize])
                .lit()
                .list(),
        );
    }
    bucket_weights
}

The result surprised me: Left Join took 561.26ms vs. 3.22s for Chained When Then.

Thoughts?

UPDATE

Combining the country and city conditions into a single when() with .and() does not make much difference; the WhenThen version still takes over 3s:

    // Test with a single When Then: country and city conditions combined with .and()
    let now = Instant::now();
    let r1 = frame.clone().with_column(
        when(col("country").eq(lit("UK")).and(col("city").eq(lit("London"))))
            .then(rf_rw_map(col("bucket"),weight_map,NULL.lit()))
            .otherwise(NULL.lit())
        )
        .collect().unwrap();
    let elapsed = now.elapsed();
    println!("Chained When Then: {:.2?}", elapsed);

Answer:

Joins are among the most optimized algorithms in Polars. A left join is executed fully in parallel and has many performance-related fast paths. If you want to combine data based on equality, you should almost always choose a join.
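To illustrate why a join suits this problem, here is a simplified, single-threaded sketch of a hash-based left join in plain Rust (std only, no Polars). It is not Polars' implementation, which is parallel and has many more fast paths; the function name hash_left_join and the Key type are hypothetical. The point is the cost shape: one pass to build a hash table over the small weights table, then exactly one lookup per row of the large frame.

```rust
use std::collections::HashMap;

/// Hypothetical join key for this sketch: (country, city, bucket).
type Key = (&'static str, &'static str, &'static str);

/// Simplified single-threaded hash left join (not Polars' actual code).
/// Build phase: hash the small right-hand table once.
/// Probe phase: one hash lookup per left-hand row; rows without a match
/// get `None`, mirroring the nulls a left join produces.
fn hash_left_join(left: &[Key], right: &[(Key, f64)]) -> Vec<Option<f64>> {
    let table: HashMap<Key, f64> = right.iter().copied().collect();
    left.iter().map(|k| table.get(k).copied()).collect()
}

fn main() {
    // Same weights table as in the question: UK/London, buckets 1-5.
    let buckets = ["1", "2", "3", "4", "5"];
    let ws = [0.1, 0.2, 0.3, 0.4, 0.5];
    let right: Vec<(Key, f64)> = buckets
        .iter()
        .zip(ws)
        .map(|(b, w)| (("UK", "London", *b), w))
        .collect();

    let left = [("UK", "London", "1"), ("UK", "London", "5"), ("UK", "Paris", "1")];
    // prints [Some(0.1), Some(0.5), None]
    println!("{:?}", hash_left_join(&left, &right));
}
```

Because the build side is only the five-row weights table, the probe cost is one hash of the key columns per row, independent of how many buckets there are.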

Answer:

It's difficult to say for certain without more context, but the performance difference most likely comes from how each operation is implemented. A left join is optimized for exactly this kind of equality lookup: it hashes the key columns once per row and can run in parallel. A chained when/then/otherwise, by contrast, likely evaluates every predicate over the whole column before selecting results, so with five buckets the bucket column is compared five times across all 10 million rows. It's also possible that the specific inputs in the example (a single country and city for every row) favor the join.
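The point about evaluation order can be made concrete with a simplified, std-only model of how a chained when/then/otherwise might be evaluated column-wise. This is a sketch under stated assumptions, not Polars' actual code: the hypothetical chained_when_then computes each branch's predicate as a full boolean mask over every row, selects the branch value wherever the mask is true (earlier branches take precedence), and counts the string comparisons it performs.

```rust
/// Simplified column-wise model (an assumption, not Polars' code) of
/// when(b == k1).then(w1).when(b == k2).then(w2)...otherwise(null).
/// Every predicate is evaluated over *all* rows, and the branch value is
/// selected wherever its mask is true. Branches are applied last-to-first so
/// that earlier branches overwrite later ones (first match wins).
/// Returns the result column and the number of string comparisons performed.
fn chained_when_then(bucket: &[&str], branches: &[(&str, f64)]) -> (Vec<Option<f64>>, usize) {
    // Start from the otherwise(NULL) default.
    let mut out: Vec<Option<f64>> = vec![None; bucket.len()];
    let mut comparisons = 0;
    for &(key, weight) in branches.iter().rev() {
        // Full-column predicate: compares every row, even rows already matched.
        let mask: Vec<bool> = bucket.iter().map(|b| *b == key).collect();
        comparisons += bucket.len();
        // Select `weight` where the mask is true, keep the previous value elsewhere.
        for (slot, &hit) in out.iter_mut().zip(&mask) {
            if hit {
                *slot = Some(weight);
            }
        }
    }
    (out, comparisons)
}

fn main() {
    let branches = [("1", 0.1), ("2", 0.2), ("3", 0.3), ("4", 0.4), ("5", 0.5)];
    let bucket = ["1", "4", "5", "9"];
    let (out, comparisons) = chained_when_then(&bucket, &branches);
    // prints [Some(0.1), Some(0.4), Some(0.5), None] after 20 comparisons
    println!("{:?} after {} comparisons", out, comparisons);
}
```

Under this model, the question's 10 million rows and five buckets cost 50 million string comparisons (plus the separate country and city masks), and the branch values there are list literals, which are likely more expensive to select than plain floats. A hash join needs one lookup per row.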
