Can someone explain this code to me?
c = self.config
assert len(pair_act.shape) == 3
assert len(pair_mask.shape) == 2
assert c.orientation in ['per_row', 'per_column']
if c.orientation == 'per_column':
    pair_act = jnp.swapaxes(pair_act, -2, -3)
    pair_mask = jnp.swapaxes(pair_mask, -1, -2)
It looks like pair_act is a 3D array and pair_mask is a 2D array, right? What do the numbers -1, -2, and -3 mean? For 3D arrays, my initial thought was that the array is axis 0, the column is axis 1, and the row is axis 2. So where does the minus sign come from? Any array examples would be appreciated. Thanks for the help.
Answer:
The documentation for jax.numpy.swapaxes is here: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.swapaxes.html
The effect of swapaxes is to exchange (transpose) the two given axes, which in general produces an array with a different shape:
import jax.numpy as jnp
x = jnp.arange(24).reshape((2, 3, 4))
print(x.shape)
# (2, 3, 4)
y = jnp.swapaxes(x, 1, 2)
print(y.shape)
# (2, 4, 3)
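The shape only tells part of the story. Swapping axes 1 and 2 also moves every element so that its second and third indices trade places: y[i, j, k] == x[i, k, j] for every valid index. A small check of that, with x redefined so the snippet runs on its own:

```python
import jax.numpy as jnp

x = jnp.arange(24).reshape((2, 3, 4))
y = jnp.swapaxes(x, 1, 2)  # shape (2, 4, 3)

# Every element's 2nd and 3rd indices trade places; the 1st index is unchanged.
for i in range(2):
    for j in range(4):
        for k in range(3):
            assert y[i, j, k] == x[i, k, j]

# Equivalently, each 2D slice along axis 0 is transposed.
assert (y[0] == x[0].T).all()
print(x[0])
print(y[0])
```

The first print shows x[0] with rows 0..3, 4..7, 8..11; the second shows the same numbers with rows and columns exchanged.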
As is standard in NumPy indexing, negative numbers count backward from the end. Here the indices refer to positions in the shape (which has length 3), so -2, -1 is equivalent to 1, 2. In general, axis -k of an n-dimensional array is axis n - k:
y = jnp.swapaxes(x, -2, -1)
print(y.shape)
# (2, 4, 3)
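The same rule covers the -3 from the question: for a 3D array, -1, -2 and -3 map to axes 2, 1 and 0. In the sketch below, normalize_axis is a hypothetical helper (not a JAX function) that just spells out the arithmetic axis + ndim for negative axes; the swap then checks that the negative and positive spellings of the same axes agree.

```python
import jax.numpy as jnp

def normalize_axis(axis, ndim):
    """Hypothetical helper: map a possibly negative axis to its non-negative index."""
    if not -ndim <= axis < ndim:
        raise ValueError(f"axis {axis} out of range for ndim={ndim}")
    return axis + ndim if axis < 0 else axis

x = jnp.arange(24).reshape((2, 3, 4))

# For a 3D array: -1 -> 2, -2 -> 1, -3 -> 0
mapping = {a: normalize_axis(a, x.ndim) for a in (-1, -2, -3)}
print(mapping)  # {-1: 2, -2: 1, -3: 0}

# Swapping axes -2 and -3 is the same as swapping axes 1 and 0.
a = jnp.swapaxes(x, -2, -3)
b = jnp.swapaxes(x, 1, 0)
print(a.shape)         # (3, 2, 4)
print((a == b).all())  # True
```

Negative axes are convenient when code should refer to "the last two axes" regardless of how many leading dimensions an array has.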
The result of swapaxes is equivalent to an appropriately constructed transpose operation, here with the axis permutation (0, 2, 1):
y2 = jnp.transpose(x, (0, 2, 1))
print((y == y2).all())
# True
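Applied to the code in the question: pair_act has 3 dimensions, so jnp.swapaxes(pair_act, -2, -3) exchanges its first two axes (1 and 0) and leaves the last axis alone, while pair_mask has 2 dimensions, so jnp.swapaxes(pair_mask, -1, -2) is an ordinary matrix transpose. The sketch below assumes pair_act has shape (N, N, C), i.e. two square position axes followed by a channel axis, with made-up sizes N = 4 and C = 8; the question does not state the shapes, and the function name orient is hypothetical, simply wrapping the question's branch.

```python
import jax.numpy as jnp

def orient(pair_act, pair_mask, orientation):
    """Hypothetical wrapper mirroring the branch in the question."""
    assert len(pair_act.shape) == 3
    assert len(pair_mask.shape) == 2
    assert orientation in ['per_row', 'per_column']
    if orientation == 'per_column':
        # 3D: axes -2 and -3 are axes 1 and 0; the last axis (-1) is untouched.
        pair_act = jnp.swapaxes(pair_act, -2, -3)
        # 2D: axes -1 and -2 are axes 1 and 0, so this is a plain transpose.
        pair_mask = jnp.swapaxes(pair_mask, -1, -2)
    return pair_act, pair_mask

# Assumed sizes for illustration only: N = 4 positions, C = 8 channels.
N, C = 4, 8
pair_act = jnp.arange(N * N * C).reshape((N, N, C))
pair_mask = jnp.arange(N * N).reshape((N, N))

act_col, mask_col = orient(pair_act, pair_mask, 'per_column')

# Column j of the original is now row j.
j = 2
print(act_col.shape)                          # (4, 4, 8)
print((act_col[j] == pair_act[:, j]).all())   # True
print((mask_col == pair_mask.T).all())        # True
```

Because the first two axes have the same length, the shape does not change; only the data is rearranged. Presumably the code that follows operates along rows, so swapping first lets the same code operate along columns in the 'per_column' case.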