How can I efficiently decode vectors to numpy operations?

Time:12-30

I am trying to generate an encoding that maps vectors to operations between numpy arrays, and do so efficiently.

For example, if I have two operations, np.add and np.foo, and I encode np.add as 0 and np.foo as 1, what is the fastest way to go from the vector np.array([0,1,2]) to np.add(1,2)?
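Decoding a single vector is just indexing: if each op code is a position in a list of functions, the first element picks the function and the other two are its arguments. A minimal sketch, with np.subtract standing in for the hypothetical np.foo:

```python
import numpy as np

# The list position is the op code. np.subtract is a stand-in for the
# hypothetical np.foo from the question.
operation_list = [np.add, np.subtract]

vec = np.array([0, 1, 2])
code, a, b = vec                     # op code, then the two operands
result = operation_list[code](a, b)  # same as np.add(1, 2)
print(result)                        # 3
```

This avoids the if/elif chain, but it is still one Python-level call per vector.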

Essentially a more efficient way to do:

import numpy as np

# np.foo and np.witchcraft are placeholders; np.subtract and np.multiply
# stand in for them so the example runs.
operation_list = [np.add, np.subtract, np.multiply]

# In practice this array will be generated
# automatically
arr_of_ops = np.array([[0, 1, 2],   # np.add(1, 2)
                       [1, 1, 2],   # np.foo(1, 2), here np.subtract
                       [0, 2, 2],   # np.add(2, 2)
                       [2, 3, 2]])  # np.witchcraft(3, 2), here np.multiply

# This is the function I want to implement as efficiently as possible
def evaluate_encoding(arr_of_ops):
    results = []
    for row in arr_of_ops:
        if row[0] == 0:
            results.append(np.add(row[1], row[2]))
        elif row[0] == 1:
            results.append(np.subtract(row[1], row[2]))  # np.foo
        elif row[0] == 2:
            results.append(np.multiply(row[1], row[2]))  # np.witchcraft

    return np.array(results)
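A vectorized sketch of the same function, assuming the op codes are positions in operation_list (np.subtract and np.multiply again stand in for the placeholders, and evaluate_encoding_vectorized is a hypothetical name). Instead of looping over rows, it loops over the few operations and applies each ufunc once to all rows that use it, selected with a boolean mask:

```python
import numpy as np

# Stand-ins for np.add, np.foo, np.witchcraft; the list position is the op code.
operation_list = [np.add, np.subtract, np.multiply]

def evaluate_encoding_vectorized(arr_of_ops, ops=operation_list):
    codes, a, b = arr_of_ops[:, 0], arr_of_ops[:, 1], arr_of_ops[:, 2]
    # Reject codes with no matching operation instead of returning garbage.
    if codes.min() < 0 or codes.max() >= len(ops):
        raise ValueError("unknown op code in arr_of_ops")
    results = np.empty(len(arr_of_ops), dtype=np.result_type(a, b))
    # One ufunc call per operation, not one Python call per row.
    for code, op in enumerate(ops):
        mask = codes == code
        if mask.any():
            results[mask] = op(a[mask], b[mask])
    return results

arr_of_ops = np.array([[0, 1, 2],
                       [1, 1, 2],
                       [0, 2, 2],
                       [2, 3, 2]])
print(evaluate_encoding_vectorized(arr_of_ops))  # [ 3 -1  4  6]
```

The output dtype is taken from the operands, so an operation that returns floats (e.g. np.divide) would be truncated; use a float dtype for results in that case.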

Answer:

One way is to use pandas, which is fairly efficient because its column operations are vectorized.

Turn your data into a DataFrame and select the rows for each operation with a boolean mask:

import numpy as np
import pandas as pd

arr_of_ops = np.array([
    [0, 1, 2],
    [1, 1, 2],
    [0, 2, 2],
    [1, 3, 2]
])

df = pd.DataFrame(data=arr_of_ops, columns=("operation", "a", "b"))
indices_sum = df.operation.eq(0)
indices_diff = df.operation.eq(1)
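df["results"] = np.nan
# Use .loc for masked assignment; chained indexing (df.results[mask] = ...)
# may silently write to a copy.
df.loc[indices_sum, "results"] = df.a[indices_sum] + df.b[indices_sum]
df.loc[indices_diff, "results"] = df.a[indices_diff] - df.b[indices_diff]
print(df)
#    operation  a  b  results
# 0          0  1  2      3.0
# 1          1  1  2     -1.0
# 2          0  2  2      4.0
# 3          1  3  2      1.0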
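The same masking also works in plain NumPy without building a DataFrame. One alternative, not part of the answer above, is np.select: it evaluates every candidate operation on all rows and then picks, per row, the result whose condition holds. That wastes work when there are many or expensive operations, but it is short and loop-free for a couple of cheap ones. A minimal sketch using the answer's encoding (0 = sum, 1 = difference):

```python
import numpy as np

arr_of_ops = np.array([
    [0, 1, 2],
    [1, 1, 2],
    [0, 2, 2],
    [1, 3, 2],
])
op, a, b = arr_of_ops.T  # columns: op code, first operand, second operand

# Every choice is computed for every row; the conditions pick one per row.
results = np.select(
    [op == 0, op == 1],  # 0 -> sum, 1 -> difference
    [a + b, a - b],
    default=0,           # value for unknown op codes (an arbitrary choice)
)
print(results)  # [ 3 -1  4  1]
```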

Answer:

  1. Use JAX and you get your choice of CPU, GPU, or TPU: XLA has backends for the hardware of your choice.
  2. Use a dict instead of the if/else chain (mapping op_code to a jax.numpy op).
  3. Lose the loop and, if needed, use jax.pmap to map the ops over the data across multiple devices.
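A minimal JAX sketch of points 2 and 3, under a few assumptions: jnp.subtract and jnp.multiply stand in for the placeholder ops, and the "dict" becomes a tuple indexed by op code, since a Python dict lookup cannot use a traced value inside jit. jax.lax.switch does the dispatch, and jax.vmap replaces the row loop on a single device; jax.pmap would be the multi-device variant and is not shown. evaluate_encoding_jax is a hypothetical name.

```python
import jax
import jax.numpy as jnp

# Branch table: the tuple position is the op code. jnp.subtract and
# jnp.multiply are stand-ins for the question's placeholder ops.
BRANCHES = (jnp.add, jnp.subtract, jnp.multiply)

def apply_row(row):
    # lax.switch dispatches on an integer that may be a traced value,
    # which a Python dict lookup or if/elif cannot do under jit/vmap.
    return jax.lax.switch(row[0], BRANCHES, row[1], row[2])

# vmap maps apply_row over the rows (no Python loop); jit compiles it with XLA.
evaluate_encoding_jax = jax.jit(jax.vmap(apply_row))

arr_of_ops = jnp.array([[0, 1, 2],
                        [1, 1, 2],
                        [0, 2, 2],
                        [2, 3, 2]])
results = evaluate_encoding_jax(arr_of_ops)  # values 3, -1, 4, 6
print(results)
```

Under vmap, a switch on a batched index evaluates every branch and selects per row, much like np.select. lax.switch also clamps out-of-range op codes to the first or last branch instead of raising, so validate the codes beforehand if that matters.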