Is it safe to read the value of numpy.empty or jax.numpy.empty?

Time:07-09

In Flax, we typically initialize a model by passing in a dummy input array and letting the library figure out the correct parameter shapes via shape inference. For example, this is what the tutorial does:

import jax.numpy as jnp
import optax
from flax.training import train_state

def create_train_state(rng, learning_rate, momentum):
  """Creates initial `TrainState`."""
  cnn = CNN()  # the flax.linen Module defined earlier in the tutorial
  params = cnn.init(rng, jnp.ones([1, 28, 28, 1]))['params']
  tx = optax.sgd(learning_rate, momentum)
  return train_state.TrainState.create(
      apply_fn=cnn.apply, params=params, tx=tx)

It is worth noting that the concrete values in jnp.ones([1, 28, 28, 1]) do not matter, as shape inference relies only on the input's shape. I can replace it with jnp.zeros([1, 28, 28, 1]) or jax.random.normal(jax.random.PRNGKey(42), [1, 28, 28, 1]) and get exactly the same result.
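To illustrate why only the shape matters, here is a minimal sketch assuming a JAX installation. `init_dense` is a hypothetical toy initializer, not Flax's API: like a Flax layer's init, it reads only `x.shape` to size its parameters and never looks at the values. The last part uses `jax.eval_shape` with a `jax.ShapeDtypeStruct` to show that the parameter shapes can be computed without any concrete input at all; note that it returns shapes only, not actual parameter values.

```python
import jax
import jax.numpy as jnp

def init_dense(rng, x, features=10):
    """Hypothetical toy initializer (not Flax's API): sizes the kernel from
    x.shape only; the values stored in x are never read."""
    in_features = x.shape[-1]
    kernel = jax.random.normal(rng, (in_features, features)) * 0.01
    bias = jnp.zeros((features,))
    return {"kernel": kernel, "bias": bias}

rng = jax.random.PRNGKey(0)
shape = (1, 28 * 28)

# Three dummy inputs with the same shape but different values.
dummy_inputs = [
    jnp.ones(shape),
    jnp.zeros(shape),
    jax.random.normal(jax.random.PRNGKey(42), shape),
]
all_params = [init_dense(rng, x) for x in dummy_inputs]

# Same rng + same input shape -> identical parameters, whatever the values.
for p in all_params[1:]:
    assert jnp.array_equal(p["kernel"], all_params[0]["kernel"])
    assert jnp.array_equal(p["bias"], all_params[0]["bias"])
print(all_params[0]["kernel"].shape)  # (784, 10)

# Shapes only, with no input array allocated at all.
abstract_params = jax.eval_shape(
    init_dense, rng, jax.ShapeDtypeStruct(shape, jnp.float32))
print(abstract_params["kernel"].shape)  # (784, 10)
```

Because the initializer depends only on the RNG key and the input shape, any dummy array of the right shape (ones, zeros, or random) produces the same parameters.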

My question is: can I use jnp.empty([1, 28, 28, 1]) instead? I want to use jnp.empty to make it clear that we don't care about the values (it might also be faster, though the speedup would be negligible). However, C has a concept called trap representations, and it looks like reading from jnp.empty without overwriting it first could trigger undefined behavior. Since NumPy is a thin wrapper around C, should I worry about that?

Bonus question: let's forget about JAX and focus on vanilla NumPy. Is it safe to read from np.empty([...])? Again, I don't care about the values, but I do care about not getting a segfault.

Answer:

Because XLA does not provide a mechanism for creating uninitialized memory, jnp.empty in JAX is currently (v0.3.14) equivalent to jnp.zeros (see https://github.com/google/jax/blob/jax-v0.3.14/jax/_src/numpy/lax_numpy.py#L2007-L2009).

So, at least in the current release, it is safe to read the contents of jnp.empty arrays. But if you're going to rely on that property, I'd suggest using jnp.zeros instead, so that your assumptions remain valid even if the jnp.empty implementation changes in the future.

np.empty is different: it returns uninitialized values, so if you rely on those values your program's behavior may change unpredictably from run to run. There's no danger of memory corruption or segfaults when reading them: the memory is allocated, its contents are simply uninitialized, so the values reflect whatever bits happened to be stored there when the block was allocated.
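For the vanilla NumPy case, here is a minimal sketch of what the answer describes. Reading an np.empty buffer is ordinary access to allocated memory, and for float64 every 64-bit pattern decodes to some IEEE 754 value (possibly NaN, inf, or a subnormal), so arbitrary bits can be viewed as floats without trapping. Random bytes from np.random.default_rng stand in for "leftover" memory here; that is an illustrative assumption, since the real contents of np.empty are unpredictable.

```python
import numpy as np

# np.empty allocates a buffer of the requested shape/dtype but skips
# initialization: reading it is memory-safe, the values are just arbitrary.
a = np.empty((1, 28, 28, 1), dtype=np.float32)
print(a.shape, a.dtype)  # shape and dtype are exactly what was requested

# Stand-in for "whatever bits were left in memory": random bytes viewed as
# float64. Every 64-bit pattern is a valid IEEE 754 double, so this never
# traps; some entries may simply be NaN, inf, or subnormal.
leftover = np.frombuffer(np.random.default_rng(0).bytes(8 * 1000), dtype=np.float64)
print(leftover.size, "values,", np.count_nonzero(np.isnan(leftover)), "NaNs")

# If the values must be reproducible, initialize explicitly (or use np.zeros).
a.fill(0)
assert np.array_equal(a, np.zeros_like(a))
```

The practical risk with np.empty is therefore nondeterministic values, not crashes; filling the array or using np.zeros removes that risk.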
