How to improve the speed of `length(unique(x)) == n`?

Time:11-16

In a custom function, I want to run an if branch only when my vector has a single unique value. I can use length(unique(x)) == 1, but I think this could be more efficient: instead of collecting all the unique values in the vector and then counting them, I could stop as soon as I find a value that differs from the first one:

# Should be TRUE
test <- rep(1, 1e7)

bench::mark(
  length(unique(test)) == 1,
  all(test == test[1])
)
#> Warning: Some expressions had a GC in every iteration; so filtering is disabled.
#> # A tibble: 2 × 6
#>   expression                     min   median `itr/sec` mem_alloc `gc/sec`
#>   <bch:expr>                <bch:tm> <bch:tm>     <dbl> <bch:byt>    <dbl>
#> 1 length(unique(test)) == 1  154.1ms  158.6ms      6.31   166.1MB     6.31
#> 2 all(test == test[1])        38.1ms   49.2ms     19.6     38.1MB     3.92

# Should be FALSE
test2 <- rep(c(1, 2), 1e7)

bench::mark(
  length(unique(test2)) == 1,
  all(test2 == test2[1])
)
#> Warning: Some expressions had a GC in every iteration; so filtering is disabled.
#> # A tibble: 2 × 6
#>   expression                      min   median `itr/sec` mem_alloc `gc/sec`
#>   <bch:expr>                 <bch:tm> <bch:tm>     <dbl> <bch:byt>    <dbl>
#> 1 length(unique(test2)) == 1  341.2ms  386.1ms      2.59   332.3MB     2.59
#> 2 all(test2 == test2[1])       59.5ms   81.1ms     11.5     76.3MB     1.92

It is indeed more efficient.

Now, suppose that I want to replace length(unique(x)) == 2. I could probably do something similar and stop as soon as I find 3 different values, but I don't see how I can generalize this to replace length(unique(x)) == n, where n can be any positive integer.

Is there an efficient and general way to do this?

(I'm looking for a solution in base R. If you can also improve the benchmark for n = 1, suggestions are welcome.)

CodePudding user response:

For n = 1, a simple Rcpp solution may be fastest.

For variable values of n, take advantage of the nmax argument in duplicated. It isn't much slower than the all(test == test[1]) solution for n = 1.
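The reason nmax helps here appears to be that duplicated() signals an error once it meets more distinct values than its hash table was sized for, which the function below catches with try(). A small check of that behaviour, assuming it holds in your R version (the vectors `few` and `many` are made up for illustration, and the exact overflow threshold is an implementation detail):

```r
# Few distinct values: duplicated() completes normally.
few <- c(1, 1, 2, 1)
ok <- tryCatch(duplicated(few, nmax = 2L), error = function(e) e)

# Many distinct values with a small nmax: expected to signal an error.
# c() is used so the input is a plain, unsorted numeric vector.
many <- c(2, 1, 3:1000)
res <- tryCatch(duplicated(many, nmax = 2L), error = function(e) e)

is.logical(ok)          # TRUE when no error was raised
inherits(res, "error")  # TRUE when the hash table overflowed
```

Because the error is raised partway through the scan, the work stops early when the vector has far more distinct values than requested.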

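f <- function(x, n = 1L) {
  if (n == 1L) {
    # One distinct value: every element must equal the first.
    all(x == x[1])
  } else {
    # duplicated() signals an error when the hash table sized by nmax
    # overflows; try() swallows it, so out keeps its default of FALSE.
    try({
      out <- FALSE
      out <- sum(!duplicated(x, nmax = max(n - 1L, 2L))) == n
    },
    silent = TRUE)
    out
  }
}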
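If relying on an error from duplicated() feels fragile, a possible alternative in base R is to scan the vector in chunks and stop once more than n distinct values have been seen. This is only a sketch, not benchmarked here; the names has_n_unique and chunk_size are hypothetical, and it assumes x is an atomic vector.

```r
# Hypothetical helper: TRUE if x has exactly n distinct values.
# Scans x chunk by chunk so it can return early.
has_n_unique <- function(x, n, chunk_size = 100000L) {
  len <- length(x)
  if (len == 0L) return(n == 0L)
  seen <- x[0]  # empty vector of the same type as x
  for (start in seq(1L, len, by = chunk_size)) {
    end <- min(start + chunk_size - 1L, len)
    # Merge the distinct values of this chunk into those seen so far.
    seen <- unique(c(seen, x[start:end]))
    # More than n distinct values: no need to look any further.
    if (length(seen) > n) return(FALSE)
  }
  length(seen) == n
}

has_n_unique(rep(1, 10), 1)      # TRUE
has_n_unique(c(1, 2, 1, 2), 2)   # TRUE
has_n_unique(c(1, 2, 3), 2)      # FALSE
```

The early exit only pays off when the extra distinct values show up near the start; when the answer is TRUE, the whole vector still has to be scanned.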

If you're willing to use the data.table package, uniqueN is faster in all but one of the cases benchmarked below (test21).

library(data.table)

test1 <- integer(1e7)
test2 <- replace(test1, sample(length(test1), 1), 1L)
test3 <- replace(test1, sample(length(test1), 2), 1:2)
test4 <- replace(test1, sample(length(test1), 3), 1:3)

microbenchmark::microbenchmark(test11 = f(test1),
                               test11u = uniqueN(test1) == 1L,
                               test21 = !(f(test2)),
                               test21u = !(uniqueN(test2) == 1L),
                               test12 = !(f(test1, 2L)),
                               test12u = !(uniqueN(test1) == 2L),
                               test22 = f(test2, 2L),
                               test22u = uniqueN(test2) == 2L,
                               test32 = !(f(test3, 2L)),
                               test32u = !(uniqueN(test3) == 2L),
                               test33 = f(test3, 3L),
                               test33u = uniqueN(test3) == 3L,
                               test42 = !(f(test4, 2L)),
                               test42u = !(uniqueN(test4) == 2L),
                               times = 10L,
                               check = "identical"
)
#> Unit: milliseconds
#>     expr     min      lq     mean   median      uq     max neval
#>   test11 19.7039 21.0140 24.53535 25.34470 28.1465 29.2522    10
#>  test11u  8.4288  8.4487 11.75907  9.01450 15.7135 17.6662    10
#>   test21 15.6819 16.3501 19.96841 20.67490 23.4805 24.5174    10
#>  test21u 22.0334 22.9321 24.48876 23.89990 24.9918 30.1270    10
#>   test12 56.7494 58.1082 61.04971 59.28220 63.2788 70.6369    10
#>  test12u  8.2359  8.2733 11.00223  8.80760 13.6347 20.3073    10
#>   test22 56.0069 58.3666 60.97008 59.13500 65.1769 66.3205    10
#>  test22u 23.6497 24.5786 27.64053 27.53345 29.6208 33.7244    10
#>   test32 56.1626 62.4090 68.80590 63.91255 68.2029 98.7681    10
#>  test32u 22.2644 22.5803 26.76258 23.41385 32.5339 33.6664    10
#>   test33 59.1070 62.5212 63.17074 63.30035 64.3413 66.1363    10
#>  test33u 22.4709 22.7972 25.10889 23.01645 28.1002 30.9884    10
#>   test42 38.3762 44.5015 44.65828 44.81795 46.0175 51.4117    10
#>  test42u 22.4919 23.0305 24.61079 23.68880 24.0825 29.8777    10

CodePudding user response:

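To answer the second bit: I suspect if(mean(x) == x[1]) will be faster yet, with the usual warning (which applies to your code snippets as well) about comparing floats. Note that this is only a necessary condition: a vector such as c(2, 1, 3) has a mean equal to its first element without being constant, so the check can rule vectors out but not confirm them.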
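A quick check with a made-up vector shows why the mean comparison can only serve as a pre-filter and not as a drop-in replacement for all(x == x[1]):

```r
# Three distinct values whose mean happens to equal the first element.
x <- c(2, 1, 3)

mean(x) == x[1]        # TRUE: (2 + 1 + 3) / 3 is exactly 2
all(x == x[1])         # FALSE: the vector is not constant
length(unique(x))      # 3
```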

Another thing to check out might be length(rle(sort(x))$values) {disclaimer: I love rle for everything}.
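For reference, here is how the rle idea plays out on a small made-up vector. After sorting, equal values sit next to each other, so each run corresponds to one distinct value. Since sort() has to process the whole vector, I would not assume this beats unique() without benchmarking it.

```r
x <- c(3, 1, 3, 2, 1)

# Sorting groups equal values; rle() then collapses each group into one run.
runs <- rle(sort(x))

runs$values                                # 1 2 3
length(runs$values)                        # 3
length(runs$values) == length(unique(x))   # TRUE
```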

[and a demerit for me posting this before running microbenchmark on samples]
