Is there a way to make the following algo more efficient?

Time:11-09

Given the following algorithm, which returns a value from a predefined enumeration when a random input value u is less than a cumulative sum of the predefined probabilities in the prob vector:

val <- 1:5
prob <- c(1/3, 1/30, 2/15, 7/30, 4/15)
u <- runif(1)
if (u < prob[1]) {
  x <- val[1]
} else if (u < prob[1] + prob[2]) {
  x <- val[2]
} else if (u < prob[1] + prob[2] + prob[3]) {
  x <- val[3]
} else if (u < prob[1] + prob[2] + prob[3] + prob[4]) {
  x <- val[4]
} else {
  x <- val[5]
}

Is there a way to make all this more efficient? I can't figure out how to do it differently.

Answer:

An alternative would be to use findInterval on the cumulative probabilities, with 0 prepended as the first break. For example:

val <- 1:5
prob <- c(1/3,1/30,2/15,7/30,4/15)
u <- runif(1)
val[findInterval(u, cumsum(c(0, prob)))]

This also works for a whole vector of uniform draws at once:

u <- runif(1000)
val[findInterval(u, cumsum(c(0, prob)))]
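A quick sanity check, as a sketch: draw many values this way and compare the observed frequencies with prob. The seed, the number of draws n, and the names breaks and freq are arbitrary choices for illustration.

```r
val <- 1:5
prob <- c(1/3, 1/30, 2/15, 7/30, 4/15)

# Breaks are 0, 1/3, 11/30, 1/2, 11/15, 1 (up to floating-point rounding).
breaks <- cumsum(c(0, prob))

set.seed(1)
n <- 100000
u <- runif(n)
x <- val[findInterval(u, breaks)]

# Observed share of each value, to compare against prob.
freq <- tabulate(x, nbins = length(val)) / n
round(rbind(expected = prob, observed = freq), 3)
```

With this many draws the two rows should agree to about two decimal places.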

Answer:

I don't know whether this is specifically a question about an algorithm or whether you want to solve the problem faster.

sample(val, size = 1, prob = prob) would certainly be faster than your solution, but I don't know whether the internal algorithm is fundamentally any better (e.g. in terms of operation count) than yours.
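A minimal sketch of calling sample() directly. For more than one draw, replace = TRUE is needed, because sample() samples without replacement by default and size would otherwise exceed length(val). The seed and the name xs are arbitrary.

```r
val <- 1:5
prob <- c(1/3, 1/30, 2/15, 7/30, 4/15)

# A single draw, with the same distribution as the if/else chain in the question.
x <- sample(val, size = 1, prob = prob)

# Many draws at once; replace = TRUE because size > length(val).
set.seed(2)
xs <- sample(val, size = 10000, replace = TRUE, prob = prob)
table(xs) / length(xs)
```

One caveat: if val could ever have length 1, sample(val, ...) treats a single number n >= 1 as 1:n, so val[sample.int(length(val), size, replace = TRUE, prob = prob)] is the safer form.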

Looking at the C source code, you can see that the algorithm is like yours, except that it computes the cumulative probabilities up front to avoid repeated addition, and it permutes the elements into descending order of probability (I think that's for numerical stability rather than efficiency).
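The steps described above can be sketched in R as follows. This is a hypothetical translation written for illustration (draw_with_replacement is not a base R function, and this is not the actual C code): the probabilities are normalised and sorted into descending order once, their cumulative sums are computed once, and each draw is a linear scan over those sums.

```r
# Hypothetical helper, not part of base R: mirrors the steps described above.
draw_with_replacement <- function(val, prob, size) {
  prob <- prob / sum(prob)               # normalise so prob need not sum to 1
  ord <- order(prob, decreasing = TRUE)  # indices in descending order of probability
  cp <- cumsum(prob[ord])                # cumulative probabilities, computed once
  out <- vector(typeof(val), size)
  for (i in seq_len(size)) {
    u <- runif(1)
    j <- 1L
    # Linear scan; the last position catches any rounding error in cp.
    while (j < length(cp) && u > cp[j]) j <- j + 1L
    out[i] <- val[ord[j]]
  }
  out
}

prob <- c(1/3, 1/30, 2/15, 7/30, 4/15)
set.seed(4)
x <- draw_with_replacement(1:5, prob, 20000)
round(tabulate(x, nbins = 5) / length(x), 3)
```

Because the loop runs in R code, this will be slower than sample(); the point is only to show where the up-front work goes.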

(This is for sampling with replacement; the without-replacement code looks very similar, except that the machinery for removing previously sampled values makes it a little more complicated.)
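A rough sketch of the without-replacement idea, assuming all probabilities are positive: after each draw, the chosen value and its probability are removed, and the remaining probabilities are renormalised before the next scan. draw_without_replacement is a hypothetical helper, not R's implementation; in practice sample(val, size = 3, prob = prob) does this directly.

```r
# Hypothetical helper, not R's implementation. Assumes every prob is positive.
draw_without_replacement <- function(val, prob, size) {
  stopifnot(size <= length(val))
  out <- val[0]                          # empty vector of the same type as val
  for (i in seq_len(size)) {
    cp <- cumsum(prob) / sum(prob)       # renormalise over the remaining values
    u <- runif(1)
    j <- 1L
    while (j < length(cp) && u > cp[j]) j <- j + 1L
    out <- c(out, val[j])
    val <- val[-j]                       # remove the drawn value ...
    prob <- prob[-j]                     # ... and its probability
  }
  out
}

set.seed(5)
draw_without_replacement(1:5, c(1/3, 1/30, 2/15, 7/30, 4/15), size = 3)
```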

Answer:

You can use cumsum and which:

set.seed(3)
u <- runif(1)
#[1] 0.702374

val[which(u < cumsum(prob))[1]]
#[1] 4
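Note that which(u < cumsum(prob))[1] handles one u at a time. A sketch of applying it to a vector of draws with vapply, checked against the findInterval version from the first answer (the seed and the names cp, x_which and x_interval are arbitrary). Both pick the first cumulative probability that exceeds u, so they should agree draw for draw.

```r
val <- 1:5
prob <- c(1/3, 1/30, 2/15, 7/30, 4/15)
cp <- cumsum(prob)

set.seed(6)
u <- runif(1000)

# which() works on a single u, so map it over the vector of draws.
x_which <- vapply(u, function(ui) val[which(ui < cp)[1]], FUN.VALUE = val[1])

# The same draws mapped with findInterval, which is vectorised over u.
x_interval <- val[findInterval(u, cumsum(c(0, prob)))]

identical(x_which, x_interval)
```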