swapaxes and how it is implemented?

Time:02-11

I'm wondering if someone can explain this code to me?

c = self.config

assert len(pair_act.shape) == 3
assert len(pair_mask.shape) == 2
assert c.orientation in ['per_row', 'per_column']

if c.orientation == 'per_column':
  pair_act = jnp.swapaxes(pair_act, -2, -3)
  pair_mask = jnp.swapaxes(pair_mask, -1, -2)

It looks like pair_act is a 3D array and pair_mask is a 2D array? What are the numbers -1, -2, and -3? For 3D arrays, my initial thought is that the array is 0, column is 1, and row is 2. So where does the - sign come from? Any array examples would be appreciated. Thanks for the help.

CodePudding user response:

The documentation for jax.numpy.swapaxes is here: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.swapaxes.html

The effect of swapaxes is to exchange the two given axes. The data is unchanged, but the two corresponding entries of the shape trade places (so the shape changes whenever those two axes have different lengths):

import jax.numpy as jnp

x = jnp.arange(24).reshape((2, 3, 4))
print(x.shape)
# (2, 3, 4)

y = jnp.swapaxes(x, 1, 2)
print(y.shape)
# (2, 4, 3)
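
To see where individual elements end up, here is a small check (a sketch that rebuilds the same x so it runs on its own). It suggests that element [i, j, k] of the swapped array is element [i, k, j] of the original:

```python
import jax.numpy as jnp

x = jnp.arange(24).reshape((2, 3, 4))
y = jnp.swapaxes(x, 1, 2)

# Axis 0 is untouched; axes 1 and 2 trade places,
# so y[i, j, k] holds the value stored at x[i, k, j].
print(int(x[1, 2, 3]), int(y[1, 3, 2]))
# 23 23
print(int(x[0, 1, 2]), int(y[0, 2, 1]))
# 6 6
```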

As is standard in numpy indexing, negative numbers count backward from the end. Here the indices refer to positions in the shape (which has length 3), so axis -1 is axis 2, -2 is axis 1, and -3 is axis 0. Swapping -2 and -1 is therefore equivalent to swapping 1 and 2:

y = jnp.swapaxes(x, -2, -1)
print(y.shape)
# (2, 4, 3)
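
Applied to the snippet in the question, -3 is axis 0 of the 3D pair_act and -2 is its axis 1, while -2 and -1 are axes 0 and 1 of the 2D pair_mask. Here is a sketch with made-up shapes (the sizes 5, 7, and 8 are assumptions, chosen only so that each axis is distinguishable):

```python
import jax.numpy as jnp

# Hypothetical stand-ins for the arrays in the question.
pair_act = jnp.zeros((5, 7, 8))   # 3D: axes 0, 1, 2 are also -3, -2, -1
pair_mask = jnp.ones((5, 7))      # 2D: axes 0, 1 are also -2, -1

# Same calls as in the 'per_column' branch of the question.
pair_act_t = jnp.swapaxes(pair_act, -2, -3)
pair_mask_t = jnp.swapaxes(pair_mask, -1, -2)

print(pair_act_t.shape)
# (7, 5, 8)
print(pair_mask_t.shape)
# (7, 5)
```

In other words, the 'per_column' branch swaps the first two axes of both arrays and leaves the last axis of pair_act alone.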

The result of a swapaxes is equivalent to an appropriately constructed transpose operation:

y2 = jnp.transpose(x, (0, 2, 1))
print((y == y2).all())
# True
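
One way to build that transpose for any pair of axes is to start from the identity permutation and exchange the two entries. The helper below is hypothetical (its name, swapaxes_via_transpose, is not part of JAX) and is only meant to illustrate the equivalence, including how negative axes map to positive ones. It does not describe how JAX implements swapaxes internally:

```python
import jax.numpy as jnp

def swapaxes_via_transpose(a, axis1, axis2):
    """Illustrative helper (not a JAX function): swap two axes using transpose."""
    # Map negative axes to their positive equivalents, e.g. -1 -> ndim - 1.
    # Note: unlike jnp.swapaxes, this sketch does not reject out-of-range axes.
    axis1 %= a.ndim
    axis2 %= a.ndim
    perm = list(range(a.ndim))  # identity permutation, e.g. [0, 1, 2]
    perm[axis1], perm[axis2] = perm[axis2], perm[axis1]
    return jnp.transpose(a, tuple(perm))

x = jnp.arange(24).reshape((2, 3, 4))
y3 = swapaxes_via_transpose(x, -2, -3)
print(y3.shape)
# (3, 2, 4)
print(bool((y3 == jnp.swapaxes(x, -2, -3)).all()))
# True
```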