I have the code below which I want to translate into pytorch. I'm looking for a way to translat

Time:03-22

I need to translate this code to PyTorch. The code below uses np.vectorize, and I am looking for a PyTorch equivalent.

import numpy as np

class SimplexPotentialProjection(object):
    def __init__(self, potential, inversePotential, strong_convexity_const, precision=1e-10):
        self.inversePotential = inversePotential
        # np.vectorize turns the scalar functions into elementwise functions on arrays
        self.gradPsi = np.vectorize(potential)
        self.gradPsiInverse = np.vectorize(inversePotential)
        self.precision = precision
        self.strong_convexity_const = strong_convexity_const

Answer:

The documentation for numpy.vectorize states:

The vectorize function is provided primarily for convenience, not for performance. The implementation is essentially a for loop.
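To see what the quoted sentence means in practice, the sketch below compares np.vectorize with an explicit Python loop over the array and with a natively vectorized NumPy expression. The scalar function grad_psi (1 + log x, the derivative of x log x) is a hypothetical stand-in for the question's potential.

```python
import math
import numpy as np

def grad_psi(x):
    # Hypothetical scalar function: derivative of x*log(x)
    return 1.0 + math.log(x)

x = np.array([[0.5, 1.0], [2.0, 4.0]])

# 1) np.vectorize: convenient, but still calls grad_psi once per element
vectorized = np.vectorize(grad_psi)(x)

# 2) In effect what np.vectorize does: one Python call per element
looped = np.empty_like(x)
for idx in np.ndindex(x.shape):
    looped[idx] = grad_psi(x[idx])

# 3) Natively vectorized: a single NumPy call over the whole array
native = 1.0 + np.log(x)

print(np.allclose(vectorized, looped), np.allclose(vectorized, native))  # True True
```

All three produce the same values; only the third avoids a Python-level call per element.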

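Therefore, to convert your code to PyTorch you could simply apply potential and inversePotential element by element, in a Python loop over their tensor arguments. However, that is likely to be very slow, since every element goes through a separate Python function call. You would do better to re-implement your functions so that they operate "natively" on whole tensors, using PyTorch's built-in elementwise operations.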
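A minimal PyTorch sketch of both options follows. Everything in it is illustrative: the helper torch_vectorize, the example functions (the gradient of the negative-entropy potential, 1 + log x, and its inverse, exp(y - 1)), and the elementwise flag are hypothetical, not part of the question's code or of PyTorch's API.

```python
import math
import torch

def torch_vectorize(fn):
    """Hypothetical loop-based analogue of np.vectorize for tensors.

    Calls the scalar Python function `fn` once per element, so it is slow,
    and autograd cannot track gradients through it because the values pass
    through Python floats.
    """
    def wrapped(t: torch.Tensor) -> torch.Tensor:
        values = [fn(v) for v in t.reshape(-1).tolist()]
        return torch.tensor(values, dtype=t.dtype, device=t.device).reshape(t.shape)
    return wrapped

# Hypothetical scalar versions (what np.vectorize would have wrapped)
def grad_entropy_scalar(x):
    return 1.0 + math.log(x)

def grad_entropy_inverse_scalar(y):
    return math.exp(y - 1.0)

# Native tensor versions: one call handles the whole tensor (autograd and GPU friendly)
def grad_entropy(x: torch.Tensor) -> torch.Tensor:
    return 1.0 + torch.log(x)

def grad_entropy_inverse(y: torch.Tensor) -> torch.Tensor:
    return torch.exp(y - 1.0)


class SimplexPotentialProjection(object):
    def __init__(self, potential, inversePotential, strong_convexity_const,
                 precision=1e-10, elementwise=False):
        # `elementwise=True` (hypothetical flag, not in the original) wraps scalar
        # Python functions with the slow loop; otherwise the functions are
        # assumed to accept and return whole tensors already.
        wrap = torch_vectorize if elementwise else (lambda f: f)
        self.inversePotential = inversePotential
        self.gradPsi = wrap(potential)
        self.gradPsiInverse = wrap(inversePotential)
        self.precision = precision
        self.strong_convexity_const = strong_convexity_const


x = torch.tensor([0.2, 0.3, 0.5], dtype=torch.float64)
# 1.0 is an arbitrary placeholder for strong_convexity_const
slow = SimplexPotentialProjection(grad_entropy_scalar, grad_entropy_inverse_scalar, 1.0, elementwise=True)
fast = SimplexPotentialProjection(grad_entropy, grad_entropy_inverse, 1.0)
print(torch.allclose(slow.gradPsi(x), fast.gradPsi(x)))  # True
```

The native version is the one to prefer: it runs as a few tensor kernels instead of a Python loop, and gradients flow through it. If your scalar functions are already written with torch operations, torch.vmap (PyTorch 2.x) is another option, though it cannot handle Python control flow that depends on tensor values.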
