lapply slower than hard coding

Time:06-28

I have a reasonably large data.table (10 million rows):

library(data.table)

# nrows
x <- 1e7

# data
df <- data.table(group = sample(letters[1:2], x, replace=TRUE)
                 , user = sample(1:x, x, replace=TRUE)
                 , price = sample(1:3, x, replace=TRUE)
                 , quantity = sample(1:3, x, replace=TRUE)
                 ); df

      group    user price quantity
   1:     a 8286968     1        3
   2:     b 8652340     1        3
   3:     a 7388954     1        1
   4:     b 6932335     3        3
   5:     a 1468016     1        2

Here is the timing of the hard-coded method:

# hardcode
system.time(
df[, .(price = sum(price), quantity = sum(quantity))
    , .(user, group)
    ][, .(mean_price = mean( price ), mean_quantity = mean(quantity))
      , .(group)
      ]
)

   user  system elapsed 
   2.77    0.28    1.41 

versus that of the lapply method:

# lapply
x <- c('price', 'quantity')

system.time(
df[, lapply(.SD, \(i) sum(i))
  , .SDcols = x
  , .(user, group)
  ][, lapply(.SD, \(i) mean(i))
    , .SDcols = x
    , .(group)
    ]
)

   user  system elapsed 
  18.86    0.10   17.86 

What could be the cause of the big run time difference?

Answer:

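This is because data.table's internal optimisation of mean and sum isn't applied when they are wrapped in the anonymous functions \(i) sum(i) and \(i) mean(i). data.table only recognises these functions when they appear by name, so the lambda is evaluated in R once per group, and grouping by user and group here produces millions of groups.

From data.table's optimisation documentation (?datatable.optimize): base::mean function is internally optimised to use data.table's fastmean function. mean() from base is an S3 generic and gets slow with many groups. At the higher optimisation level (GForce), sum and mean in a grouped j are replaced by internal versions that compute all groups in compiled code.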
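One way to check which path data.table takes is the verbose = TRUE argument of [.data.table, which prints the optimisation steps applied to j. The sketch below uses a small made-up table dt (hypothetical data standing in for the question's df) and a lambda written as function(i) sum(i), which behaves the same as \(i) sum(i). With the bare sum the log should report GForce being applied, while with the lambda it should not (at least in the versions the timings above reflect; exact log wording varies between releases).

```r
library(data.table)

# Small hypothetical stand-in for the question's df
set.seed(1)
dt <- data.table(group    = sample(letters[1:2], 100, replace = TRUE),
                 user     = sample(1:20, 100, replace = TRUE),
                 price    = sample(1:3, 100, replace = TRUE),
                 quantity = sample(1:3, 100, replace = TRUE))
cols <- c("price", "quantity")

# Bare function name: data.table can recognise sum() and optimise it
fast <- dt[, lapply(.SD, sum), by = .(user, group), .SDcols = cols, verbose = TRUE]

# Lambda (same as \(i) sum(i)): sum() is hidden inside an anonymous function
slow <- dt[, lapply(.SD, function(i) sum(i)), by = .(user, group), .SDcols = cols,
           verbose = TRUE]

# Same numbers either way; only the evaluation path (and speed) differs
all.equal(fast, slow)
```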

lapply itself isn't causing the problem, as it also gets optimised:

The expression dt[, lapply(.SD, fun), by=.] gets optimised to dt[, list(fun(a), fun(b), ...), by=.] where a,b, ... are columns in .SD. This improves performance tremendously.
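To illustrate that this rewrite changes how j is evaluated but not what it returns, here is a small sketch comparing the question's hard-coded form with the lapply form on a made-up table dt (hypothetical data, much smaller than df). The second step's columns are named price and quantity instead of mean_price and mean_quantity so the two results can be compared directly.

```r
library(data.table)

# Small hypothetical table with the same columns as the question's df
set.seed(42)
dt <- data.table(group    = sample(c("a", "b"), 1000, replace = TRUE),
                 user     = sample(1:50, 1000, replace = TRUE),
                 price    = sample(1:3, 1000, replace = TRUE),
                 quantity = sample(1:3, 1000, replace = TRUE))
cols <- c("price", "quantity")

# Hard-coded form, as in the question
hard <- dt[, .(price = sum(price), quantity = sum(quantity)), by = .(user, group)
           ][, .(price = mean(price), quantity = mean(quantity)), by = .(group)]

# lapply form with bare functions; data.table expands j into the list form above
via_lapply <- dt[, lapply(.SD, sum), by = .(user, group), .SDcols = cols
                 ][, lapply(.SD, mean), by = .(group), .SDcols = cols]

all.equal(hard, via_lapply)
```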

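Passing sum and mean directly, rather than wrapping them in lambdas, brings the timing back in line with the hard-coded version:

system.time(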
  df[, lapply(.SD, sum)
     , .SDcols = x
     , .(user, group)
  ][, lapply(.SD,  mean)
    , .SDcols = x
    , .(group)
  ]
)

   user  system elapsed 
   2.58    0.39    1.45 
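If the function needs extra arguments, a lambda is not required: lapply passes additional arguments on to the function, so sum stays visible to data.table by name. The sketch below uses a small made-up table with NAs (hypothetical data). Whether GForce handles a particular argument such as na.rm = TRUE depends on the data.table version, so the verbose = TRUE output is the thing to check; the results are the same either way.

```r
library(data.table)

# Small hypothetical table containing missing values
set.seed(7)
dt <- data.table(group    = sample(c("a", "b"), 200, replace = TRUE),
                 price    = sample(c(1:3, NA), 200, replace = TRUE),
                 quantity = sample(c(1:3, NA), 200, replace = TRUE))
cols <- c("price", "quantity")

# Extra argument passed through lapply; sum() is still written by name
passed <- dt[, lapply(.SD, sum, na.rm = TRUE), by = .(group), .SDcols = cols,
             verbose = TRUE]

# Equivalent lambda, which hides sum() from data.table's optimiser
lambda <- dt[, lapply(.SD, function(i) sum(i, na.rm = TRUE)), by = .(group),
             .SDcols = cols]

all.equal(passed, lambda)
```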