Home > Back-end >  Faster way to multiply combination of variables in a data frame
Faster way to multiply combination of variables in a data frame

Time:07-24

I would like someone to suggest a faster way complete the following steps:

  1. Using the variables in orig_df, generate unique combination of variables.
  2. Multiply the unique combination of variables from step 1 and cbind the new variables to orig_df.

My real data frame contains anywhere from 1,000 to 5,000 variables (columns) so my current code on my real world data frames is quite slow as the combination of variables grows significantly as the number of original variables increases. For the sake of simplicity, I have included a reprex data frame with 24 variables.

library(tidyverse)
library(data.table)
   
## reprex data
orig_df <- structure(list(name1 = c("a", "a", "a", "a", "a", "a", "a", "a", 
                                    "a", "a"), name2 = c("b", "b", "b", "b", "b", "b", "b", "b", 
                                                         "b", "b"), a1 = c(2.02914670105014, 3.4823452794778, 2.10650327936948, 
                                                                           2.60655830350524, 1.96007832755073, 2.20319900226866, 2.2695046408281, 
                                                                           2.21424994202857, 2.28331831552798, 2.7281186408642), a2 = c(2.34996652741376, 
                                                                                                                                        3.70095020360926, 2.23917726097417, 1.78633957970163, 2.2069085425937, 
                                                                                                                                        2.50808324747805, 2.10149739588418, 2.40052085287649, 2.15312734529292, 
                                                                                                                                        2.26983254343561), a3 = c(2.29236552982738, 3.89205794602908, 
                                                                                                                                                                  2.00332958161127, 2.24914520112217, 2.53007733770606, 2.71754551346492, 
                                                                                                                                                                  2.0114333932435, 2.86503488517146, 2.24698418468691, 2.01899695076691
                                                                                                                                        ), a4 = c(3.41316237098706, 4.74677617124502, 2.16724250299177, 
                                                                                                                                                  2.69085298208423, 2.92152132542834, 2.77166460573895, 2.11240604408321, 
                                                                                                                                                  3.69400496418272, 2.60149397515834, 2.55143073028433), a5 = c(181, 
                                                                                                                                                                                                                153, 157, 164, 186, 160, 189, 184, 156, 144), a6 = c(109.505743217992, 
                                                                                                                                                                                                                                                                     124.993679689496, 118.331551171713, 121.149943117191, 117.845938184937, 
                                                                                                                                                                                                                                                                     116.653468860762, 115.497212844565, 119.598619239394, 118.71099689067, 
                                                                                                                                                                                                                                                                     119.980353824938), a7 = c(97.9342963550266, 101.648325143684, 
                                                                                                                                                                                                                                                                                               95.6506939872352, 99.3679362781565, 96.8558088142306, 97.807547332176, 
                                                                                                                                                                                                                                                                                               95.8501286576493, 97.0383412710708, 97.2285616820599, 96.4837667978641
                                                                                                                                                                                                                                                                     ), a8 = c(82.9441585205545, 141.582901736505, 141.917096237054, 
                                                                                                                                                                                                                                                                               133.01571492792, 139.3097949682, 159.139671804723, 135.367408753504, 
                                                                                                                                                                                                                                                                               146.93593535641, 142.989812366132, 153.648429482203), a9 = c(207.440039573019, 
                                                                                                                                                                                                                                                                                                                                            226.64200483318, 213.982245158948, 220.517879395348, 214.701746999167, 
                                                                                                                                                                                                                                                                                                                                            214.461016192938, 211.347341502214, 216.636960510465, 215.939558572729, 
                                                                                                                                                                                                                                                                                                                                            216.464120622803), a10 = c(42, 52.475, 52.475, 49.475, 52.475, 
                                                                                                                                                                                                                                                                                                                                                                       46.475, 52.475, 49.475, 49.475, 52.475), a11 = c(86.47, 98.01, 
                                                                                                                                                                                                                                                                                                                                                                                                                        93.23, 97.06, 94.23, 89.82, 92.05, 92.34, 97.02, 94.53), a12 = c(161.545, 
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                         179.1775, 166.61, 170.9975, 166.35, 164.6125, 161.65, 171.7475, 
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                         169.29, 170.7975), a13 = c(299.982, 317.99, 300.538, 311.906, 
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    305.2085, 306.263, 299.855, 305.34, 304.7145, 304.223), a14 = c(44.09, 
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    88.66, 57.84, 78.75, 73.51, 64.51, 88.59, 74.5, 77.84, 84.08), 
                          a15 = c(79.08097, 90.1388, 85.24364, 84.78847, 84.34771, 
                                  84.42352, 83.96158, 85.82465, 83.3913, 86.80556), a16 = c(26.6, 
                                                                                            38.35, 35.075, 34.075, 34.7125, 30.075, 35.35, 32.7125, 33.7125, 
                                                                                            37.35), a17 = c(61.3, 73.15, 70.45, 69.45, 69, 68.45, 70.15, 
                                                                                                            68, 68, 72.15), a18 = c(120.55, 130.45, 121.25, 123, 122.5, 
                                                                                                                                    122, 122.45, 124.5, 122.5, 126.45), a19 = c(156.35, 166.7, 
                                                                                                                                                                                157.4, 157.4, 158, 158.4, 156.7, 159, 158, 157.7), a20 = c(211, 
                                                                                                                                                                                                                                           225.35, 212.425, 214.425, 218.15, 216.425, 211, 220.15, 215.15, 
                                                                                                                                                                                                                                           214), a21 = c(292.221, 304.44, 284.18, 288.88, 293.07, 289.45, 
                                                                                                                                                                                                                                                         284.98, 297.26, 289.5, 289.45), a22 = c(113.40921, 128.19039, 
                                                                                                                                                                                                                                                                                                 120.84558, 120.76845, 119.74005, 120.69014, 120.43093, 121.66397, 
                                                                                                                                                                                                                                                                                                 118.77983, 123.26633)), row.names = c(NA, -10L), class = c("data.table", 
                                                                                                                                                                                                                                                                                                                                                            "data.frame"))

rankBy <- names(orig_df)[-c(1:2)]

## STEP 1:
## variable combination step (need to use "rankBy" vector because real df contains variables that are not used in the combination step)
colsToMultiply <- orig_df %>%
  dplyr::select(all_of(rankBy)) %>%
  names() %>%
  data.frame()
originalCols <- colsToMultiply                                          ## vector of original column names
colsToMultiply <- expand.grid(colsToMultiply$., colsToMultiply$.)
colsToMultiply <- colsToMultiply %>%
  dplyr::mutate_all(~as.character(.))
colsToMultiply$dups <- apply(colsToMultiply[,1:ncol(colsToMultiply)], 1, function(i) any(duplicated(i, incomparables = NA)))
colsToMultiply <- colsToMultiply[!colsToMultiply$dups, ]
colsToMultiply <- colsToMultiply[,-c(ncol(colsToMultiply))]
colsToMultiply <- colsToMultiply[!duplicated(t(apply(colsToMultiply[,c(1:ncol(colsToMultiply))], 1, sort))), ]  # remove: a,b ; b,a duplicates row wise

## split cols into chunks (easier on memory)
chunks <- 100
n <- nrow(colsToMultiply)
r <- rep(1:ceiling(n/chunks),each=chunks)[1:n]
colgridRunList <- split(colsToMultiply, r)
rm(list = setdiff(ls(), c("orig_df","colgridRunList","rankBy")))
gc()

## STEP 2:
## here we ADD or MULTIPLY two variables (columns) and cbind the new variable to original dataframe
orig_df <- setDT(orig_df)
newVarsList <- vector("list", length = length(colgridRunList))
for (x in seq_along(colgridRunList)) {
  colgrid <- colgridRunList[[x]]
  colList <- vector("list", length = nrow(colgrid))
  for (i in 1:nrow(colgrid)) {
    colselect1 <- colgrid[i,1]
    colselect2 <- colgrid[i,2]
    colname <- paste0(colselect1,"*",colselect2)
    new_var <- orig_df[,..colselect1] * orig_df[,..colselect2]
    colnames(new_var)[1] <- colname
    colList[[i]] <- new_var
  }
  colList <- do.call(cbind, colList)
  newVarsList[[x]] <- colList
}
newVarsList <- do.call(cbind,newVarsList)
orig_df <- cbind(orig_df, newVarsList)
rm(list = setdiff(ls(), c("orig_df")))
gc()

CodePudding user response:

Instead of creating extra combinations and then removing those and its mirror copies, it may be easier with combn

library(dplyr)
library(purrr)
sub_df <- orig_df %>% 
  select(all_of(rankBy)) 
out <- combn(sub_df, 2, FUN = function(x) x %>% 
         reduce(`*`))
colnames(out) <- combn(names(sub_df), 2, FUN = paste, collapse = "*")

-testing

> dim(out)
[1]  10 231
> dim(newVarsList)
[1]  10 231
> out[, "a14*a20"]
 [1]  9302.99 19979.53 12286.66 16885.97 16036.21 13961.58 18692.49 16401.17 16747.28 17993.12
> newVarsList[["a20*a14"]]
 [1]  9302.99 19979.53 12286.66 16885.97 16036.21 13961.58 18692.49 16401.17 16747.28 17993.12

Benchmarks

On the example data itself, it is faster with combn

> library(microbenchmark)
> microbenchmark(old = old_fn(), new = new_fn())
Unit: milliseconds
 expr       min        lq      mean    median        uq      max neval cld
  old 191.41567 207.43176 215.95941 212.13270 215.19806 547.4069   100   b
  new  71.99214  76.53548  83.51679  79.47683  84.58893 209.6556   100  a 
  • Related