I'm trying to update a rank-3 tensor using tf.tensor_scatter_nd_update. For a rank-2 tensor I used the following snippet:
import tensorflow as tf
tensor = tf.zeros((3, 2))
indices = [[0], [2]]
updates = [[5, 5], [10, 10]]
output = tf.tensor_scatter_nd_update(tensor, indices, updates)
print(output)
tf.Tensor(
[[ 5. 5.]
[ 0. 0.]
[10. 10.]], shape=(3, 2), dtype=float32)
Now I want to perform the same operation on batches, that is, on a rank-3 tensor:
tensor = tf.zeros((4, 3, 2))
indices = [[[0], [2]],
[[1], [0]],
[[0], [1]],
[[0], [2]]]
updates = [[[5, 5], [10, 10]],
[[1, 1], [7, 7]],
[[3, 3], [2, 2]],
[[5, 5], [1, 1]]]
output = tf.tensor_scatter_nd_update(tensor, indices, updates)
However, I'm getting a shape mismatch error:
Inner dimensions of output shape must match inner dimensions of updates shape Output: [4,3,2] updates: [4,2,2]
CodePudding user response:
The shapes have to line up much as they do in matrix multiplication: the indices select the target positions, and the updates must fit the slices those positions address.
- The function checks both the order of the dimensions and the number of elements in each input.
- The outer and inner dimensions in the error message play a role like the dimensions of a matrix product, for example ( 8, 2 ) * ( 2, 8 ), where the inner dimensions must agree.
- Specify the input types explicitly: the indices must be an integer type, and the updates must have the same dtype as the tensor (float32 by default for tf.zeros).
- For more than 2 dimensions you can either reshape the input back to 2 dimensions or handle each index input separately; the samples below flatten the input or use a function to step through the indices.
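The shape mismatch in the question can be checked by hand. The sketch below assumes the rule stated in the TensorFlow documentation for tf.tensor_scatter_nd_update, namely that updates must have shape indices.shape[:-1] + tensor.shape[indices.shape[-1]:]. The helper name expected_updates_shape is hypothetical and only illustrates that rule in plain Python.

```python
def expected_updates_shape(tensor_shape, indices_shape):
    """Return the shape `updates` must have for the given tensor and indices shapes."""
    # The last axis of `indices` says how many leading axes of `tensor` each index addresses.
    index_depth = indices_shape[-1]
    if index_depth > len(tensor_shape):
        raise ValueError("index depth exceeds tensor rank")
    # One update per index, each shaped like the slice that index selects.
    return tuple(indices_shape[:-1]) + tuple(tensor_shape[index_depth:])


# Rank-2 case from the question: updates of shape (2, 2) fit.
print(expected_updates_shape((3, 2), (2, 1)))        # (2, 2)
# Rank-3 case from the question: indices of shape (4, 2, 1) address only axis 0,
# so updates would need shape (4, 2, 3, 2) rather than the (4, 2, 2) that was passed.
print(expected_updates_shape((4, 3, 2), (4, 2, 1)))  # (4, 2, 3, 2)
# With full [batch, row] indices of shape (4, 2, 2), updates of shape (4, 2, 2) fit.
print(expected_updates_shape((4, 3, 2), (4, 2, 2)))  # (4, 2, 2)
```

In other words, the indices in the question select whole (3, 2) slices along the first axis, which is why updates of shape (4, 2, 2) are rejected.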
Sample 1: Simple solution
import tensorflow as tf
tensor = tf.zeros((8, 2), dtype=tf.int32)
indices = [[[0], [2]],
[[1], [0]],
[[0], [1]],
[[0], [2]]]
indices = tf.constant(indices, shape=(8, 1), dtype=tf.int32)
updates = [[[5, 5], [10, 10]],
[[1, 1], [7, 7]],
[[3, 3], [2, 2]],
[[5, 5], [1, 1]]]
updates = tf.constant(updates, shape=(8, 2), dtype=tf.int32)
output = tf.tensor_scatter_nd_update(tensor, indices, updates)
print(output)
Output
tf.Tensor(
[[5 5]
[2 2]
[1 1]
[0 0]
[0 0]
[0 0]
[0 0]
[0 0]], shape=(8, 2), dtype=int32)
Sample 2: Transforming the input is the easiest way
tensor = tf.zeros((4, 3, 2), dtype=tf.int32)
tensor = tf.reshape(tensor, (12, 2)) # merge the batch and row dimensions into one
indices = [[[0], [2], [1]],
[[1], [0], [1]],
[[0], [1], [1]],
[[0], [2], [1]]]
indices = tf.constant(indices, shape=(12, 1), dtype=tf.int32)
updates = [[[5, 5, 1], [10, 10, 1]],
[[1, 1, 1], [7, 7, 1]],
[[3, 3, 1], [2, 2, 1]],
[[5, 5, 1], [1, 1, 1]]]
updates = tf.constant(updates, shape=(12, 2), dtype=tf.int32)
output = tf.tensor_scatter_nd_update(tensor, indices, updates)
print(output)
Output
tf.Tensor(
[[5 5]
[1 1]
[1 1]
[0 0]
[0 0]
[0 0]
[0 0]
[0 0]
[0 0]
[0 0]
[0 0]
[0 0]], shape=(12, 2), dtype=int32)
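Note that the flattened samples above reuse the same row indices for every batch, so all updates land on rows 0 to 2 and overwrite one another. One possible way to keep the batches apart when flattening is to offset each row index by its batch number times the number of rows per batch. This is a sketch under that assumption, using the data from the question; the names rows, offsets and the flat_* variables are illustrative.

```python
import tensorflow as tf

batch, rows_per_batch, width = 4, 3, 2
tensor = tf.zeros((batch, rows_per_batch, width), dtype=tf.int32)

# Row to update inside each batch, shape (4, 2).
rows = tf.constant([[0, 2], [1, 0], [0, 1], [0, 2]])
# One row of values per index, shape (4, 2, 2).
updates = tf.constant([[[5, 5], [10, 10]],
                       [[1, 1], [7, 7]],
                       [[3, 3], [2, 2]],
                       [[5, 5], [1, 1]]])

# Shift the row indices of batch b by b * rows_per_batch so that every
# (batch, row) pair maps to a distinct row of the flattened tensor.
offsets = tf.range(batch)[:, tf.newaxis] * rows_per_batch    # shape (4, 1)
flat_indices = tf.reshape(rows + offsets, (-1, 1))           # shape (8, 1)
flat_updates = tf.reshape(updates, (-1, width))              # shape (8, 2)
flat_tensor = tf.reshape(tensor, (-1, width))                # shape (12, 2)

flat_output = tf.tensor_scatter_nd_update(flat_tensor, flat_indices, flat_updates)
# Restore the original (4, 3, 2) layout.
output = tf.reshape(flat_output, (batch, rows_per_batch, width))
print(output)
```

With these offsets, batch 0 writes to flat rows 0 and 2, batch 1 to rows 4 and 3, and so on, so no update overwrites another.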
Sample 3: Repeating per index, where you can write a loop or a recursive function
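index = 5
@tf.function
def reverse_count( index_input ):
    # Decrements the module-level counter (a Python side effect that runs when the function is traced)
    # and reports whether the counter is still positive.
    global index
    index = index - 1
    if index > 0 :
        return True
    else :
        return False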
tensor = tf.zeros((3, 2))
tensor = tf.cast( tensor, dtype=tf.int32 )
index = 5
indices = [[[0], [2], [1]],
[[1], [0], [1]],
[[0], [1], [1]],
[[0], [2], [1]]]
indices = tf.constant(indices, shape=(4, 3, 1), dtype=tf.int32)
updates = [[5, 5], [10, 10], [1, 1]]
updates = tf.constant(updates, shape=(3, 2), dtype=tf.int32)
print( index )
indice = tf.where([reverse_count( index )], indices[index - 1,:,:1], tf.constant([[-1, -1, -1]], shape=(3, 1)).numpy() )
print( 'tensor' )
print( tensor )
print( 'indice' )
print( indice )
print( 'updates' )
print( updates )
output = tf.tensor_scatter_nd_update(tensor, indice, updates)
print( 'output' )
print( output )
Output
5
tensor
tf.Tensor(
[[0 0]
[0 0]
[0 0]], shape=(3, 2), dtype=int32)
indice
tf.Tensor(
[[0]
[2]
[1]], shape=(3, 1), dtype=int32)
updates
tf.Tensor(
[[ 5 5]
[10 10]
[ 1 1]], shape=(3, 2), dtype=int32)
output
tf.Tensor(
[[ 5 5]
[ 1 1]
[10 10]], shape=(3, 2), dtype=int32)
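The "loop through the indexes" idea can also be written directly for the data in the question. This is only a sketch, assuming eager execution and a small, statically known batch size: it applies the rank-2 update from the question to each batch slice and stacks the results. The name slices is illustrative.

```python
import tensorflow as tf

tensor = tf.zeros((4, 3, 2))
# Per-batch row indices, shape (4, 2, 1), as in the question.
indices = tf.constant([[[0], [2]],
                       [[1], [0]],
                       [[0], [1]],
                       [[0], [2]]])
# Per-batch updates, shape (4, 2, 2), cast to match the float32 tensor.
updates = tf.constant([[[5, 5], [10, 10]],
                       [[1, 1], [7, 7]],
                       [[3, 3], [2, 2]],
                       [[5, 5], [1, 1]]], dtype=tf.float32)

# Each slice is a rank-2 problem: tensor[b] is (3, 2), indices[b] is (2, 1),
# updates[b] is (2, 2), which is exactly the working case from the question.
slices = [
    tf.tensor_scatter_nd_update(tensor[b], indices[b], updates[b])
    for b in range(tensor.shape[0])
]
output = tf.stack(slices, axis=0)  # back to shape (4, 3, 2)
print(output)
```

A Python loop like this issues one scatter call per batch, so it is mainly useful for small batches or for checking a vectorized version.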
CodePudding user response:
If you want to update by the first dimension, for example you want to change [0], [2] and [3], then each of the updates has shape (1, 3, 2), but you have three updates, so updates has shape (3, 3, 2):
import tensorflow as tf
tensor = tf.zeros((4, 3, 2))
indices = [[0], [2], [3]] # shape (3, 1)
updates = [[[1, 1], [2, 2], [3, 3]], # <- this is [0] and will go to output[0,:,:]
[[10, 10], [10, 10], [10, 10]], # <- this is [2] and will go to output[2,:,:]
[[101,102], [103, 104], [105, 106]]] # <- this is [3]; updates has shape (3, 3, 2)
output = tf.tensor_scatter_nd_update(tensor, indices, updates)
output:
tf.Tensor(
[[[ 1. 1.]
[ 2. 2.]
[ 3. 3.]]
[[ 0. 0.]
[ 0. 0.]
[ 0. 0.]]
[[ 10. 10.]
[ 10. 10.]
[ 10. 10.]]
[[101. 102.]
[103. 104.]
[105. 106.]]], shape=(4, 3, 2), dtype=float32)
If you want to update by the first two dimensions, for example you want to update the positions [0, 1] and [2, 2], then each of the updates has shape (1, 1, 2), but you have two updates, so updates will have shape (2, 1, 2):
tensor = tf.zeros((4, 3, 2))
indices = [[[0,1]], [[2,2]]] # shape (2, 1, 2)
updates = [[[1, 1]], # <- this will go to output[0,1,:]
[[10, 10]]] # <- this will go to output[2,2,:]; updates has shape (2, 1, 2)
output = tf.tensor_scatter_nd_update(tensor, indices, updates)
output:
tf.Tensor(
[[[ 0. 0.]
[ 1. 1.]
[ 0. 0.]]
[[ 0. 0.]
[ 0. 0.]
[ 0. 0.]]
[[ 0. 0.]
[ 0. 0.]
[10. 10.]]
[[ 0. 0.]
[ 0. 0.]
[ 0. 0.]]], shape=(4, 3, 2), dtype=float32)
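Applied to the batched case in the question, the second form needs one [batch, row] pair per update. A minimal sketch, assuming the per-batch row indices from the question: the batch number is paired with each row index using tf.stack, which gives indices of shape (4, 2, 2), so the original updates of shape (4, 2, 2) fit unchanged. The names rows and batch_ids are illustrative.

```python
import tensorflow as tf

tensor = tf.zeros((4, 3, 2))
# Row to update inside each batch, shape (4, 2) (the question's indices without the last axis).
rows = tf.constant([[0, 2], [1, 0], [0, 1], [0, 2]])
# Updates from the question, shape (4, 2, 2), cast to match the float32 tensor.
updates = tf.constant([[[5, 5], [10, 10]],
                       [[1, 1], [7, 7]],
                       [[3, 3], [2, 2]],
                       [[5, 5], [1, 1]]], dtype=tf.float32)

# Batch number for every entry of `rows`, shape (4, 2): [[0, 0], [1, 1], [2, 2], [3, 3]].
batch_ids = tf.broadcast_to(tf.range(4)[:, tf.newaxis], tf.shape(rows))
# Pair them up so that each index is [batch, row], shape (4, 2, 2).
indices = tf.stack([batch_ids, rows], axis=-1)

output = tf.tensor_scatter_nd_update(tensor, indices, updates)
print(output)
```

For example, updates[1, 0] = [1, 1] goes to output[1, 1, :] and updates[1, 1] = [7, 7] goes to output[1, 0, :].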