I defined a dictionary `A` and would like to find the keys given a batch of values `a`:
```python
import jax
import jax.numpy as jnp
import numpy as np

def dictionary(r):
    return dict(enumerate(r))

def get_key(val, my_dict):
    for key, value in my_dict.items():
        if np.array_equal(val, value):
            return key

# dictionary
A = jnp.array([[0, 0], [1, 1], [2, 2], [3, 3]])
A = dictionary(A)
a = jnp.array([[[1, 1], [2, 2], [3, 3]], [[0, 0], [3, 3], [2, 2]]])
keys = jax.vmap(jax.vmap(get_key, in_axes=(0, None)), in_axes=(0, None))(a, A)
```
The expected output should be:

```
keys = [[1, 2, 3], [0, 3, 2]]
```
Why am I getting `None` as an output?
JAX transforms like `vmap` work by tracing the function: they replace each input with an abstract representation of the value in order to extract the sequence of operations the function encodes (see "How to think in JAX" for a good intro to this concept).
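What this means is that, to work correctly under `vmap`, a function may only use JAX operations on its traced inputs, not NumPy ones, so your use of `np.array_equal` breaks the abstraction. `np.array_equal` cannot convert the abstract tracer into a concrete NumPy array; instead of raising, it returns `False`, so the loop in `get_key` never finds a match and the function falls off the end, implicitly returning `None`.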
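To see exactly where the line falls: the comparison on its own can be traced if you use `jnp.array_equal`, but its result is an abstract boolean, so it cannot be used to pick a Python branch, which is what a dict-lookup loop relies on. A small sketch; `target` and `branch_on_match` are illustrative names, not part of the question:

```python
import jax
import jax.numpy as jnp

a = jnp.array([[[1, 1], [2, 2], [3, 3]], [[0, 0], [3, 3], [2, 2]]])
target = jnp.array([1, 1])  # illustrative value to compare against

# Traceable: under nested vmap this yields a batched boolean array of shape (2, 3).
matches_target = jax.vmap(jax.vmap(lambda v: jnp.array_equal(v, target)))(a)

def branch_on_match(v):
    # Not traceable: under vmap, `v` is an abstract tracer, so the comparison
    # result has no concrete True/False value for the Python `if` to use.
    if jnp.array_equal(v, target):
        return 1
    return 0
```

Calling `jax.vmap(branch_on_match)` on a batch raises a concretization error (`TracerBoolConversionError` in recent JAX releases) rather than silently returning `None`, which makes the underlying limitation visible.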
Unfortunately, there's not really any replacement for this pattern, because there's no mechanism to look up an abstract JAX value in a concrete Python dictionary. If you want to do dict lookups of JAX values, you should avoid transforms and just use Python loops:
```python
keys = jnp.array([[get_key(x, A) for x in row] for row in a])
```
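When you stay outside transforms, the per-element linear scan in `get_key` can also be replaced by a reverse dictionary keyed on hashable tuples. This is a sketch under the assumption that the values are small integer vectors that round-trip exactly through `.tolist()`; `reverse` is a hypothetical name. It only works on concrete arrays, and unlike `get_key`, a missing value raises `KeyError` instead of returning `None`:

```python
import jax.numpy as jnp

A = dict(enumerate(jnp.array([[0, 0], [1, 1], [2, 2], [3, 3]])))
a = jnp.array([[[1, 1], [2, 2], [3, 3]], [[0, 0], [3, 3], [2, 2]]])

# Reverse mapping: tuple of the stored value -> its key (concrete values only).
reverse = {tuple(v.tolist()): k for k, v in A.items()}

# Each lookup is a hash-table hit instead of a scan over every dict entry.
keys = jnp.array([[reverse[tuple(x.tolist())] for x in row] for row in a])
```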
On the other hand, I suspect this is more of an XY problem; your goal is not to look up dictionary values within a jax transform, but rather to solve some problem. Perhaps you should ask a question about how to solve the problem, rather than how to get around an issue with the solution you have tried.
But if you're willing to not use the dict directly, an alternative `get_key` implementation that is compatible with JAX might look something like this:
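```python
import jax.numpy as jnp

def get_key(val, my_dict):
    keys = jnp.array(list(my_dict.keys()))
    values = jnp.array(list(my_dict.values()))
    # size=1 keeps the output shape static; with no match, the index is padded with 0.
    (idx,) = jnp.where((values == val).all(-1), size=1)
    return keys[idx[0]]  # scalar key, so the nested vmap returns shape (2, 3)
```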
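If you don't need the dict at all, a fully vectorized variant can skip `vmap` and compare every query against every stored value at once via broadcasting. This is a swapped-in technique rather than the approach above, sketched under the assumptions that the dict's values can be stacked into one array (`table`) with their keys in a parallel array (`table_keys`); both names, `lookup_keys`, and the `-1` "not found" sentinel are hypothetical choices:

```python
import jax.numpy as jnp

# Hypothetical layout: stored values stacked in key order, keys in a parallel array.
table_keys = jnp.array([0, 1, 2, 3])
table = jnp.array([[0, 0], [1, 1], [2, 2], [3, 3]])
a = jnp.array([[[1, 1], [2, 2], [3, 3]], [[0, 0], [3, 3], [2, 2]]])

def lookup_keys(queries, table_keys, table, missing=-1):
    # matches[..., k] is True when the query equals table[k] in every component.
    matches = (queries[..., None, :] == table).all(-1)
    # argmax returns the first index of the maximum, i.e. the first match
    # (index 0 when nothing matches, which is masked out below).
    idx = jnp.argmax(matches.astype(jnp.int32), axis=-1)
    found = matches.any(-1)
    return jnp.where(found, table_keys[idx], missing)

keys = lookup_keys(a, table_keys, table)
```

Since it uses only `jax.numpy` operations, `lookup_keys` can also be wrapped in `jax.jit`. The trade-off is memory: the intermediate comparison has shape (queries..., number of table entries, vector length), which can get large for big tables.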