Update tensor indices with provided values for batches

Time:09-28

I'm trying to update a rank-3 tensor using tf.tensor_scatter_nd_update. For a rank-2 tensor I used the following snippet:

import tensorflow as tf

tensor = tf.zeros((3, 2))
indices = [[0], [2]]
updates = [[5, 5], [10, 10]]
output = tf.tensor_scatter_nd_update(tensor, indices, updates)
print(output)

Output:

tf.Tensor(
[[ 5.  5.]
 [ 0.  0.]
 [10. 10.]], shape=(3, 2), dtype=float32)

Now I want to perform the same operation per batch, i.e. on a rank-3 tensor:

tensor = tf.zeros((4, 3, 2))
indices = [[[0], [2]],
           [[1], [0]],
           [[0], [1]],
           [[0], [2]]]       
updates = [[[5, 5], [10, 10]], 
           [[1, 1], [7, 7]],
           [[3, 3], [2, 2]],
           [[5, 5], [1, 1]]]    
output = tf.tensor_scatter_nd_update(tensor, indices, updates)

However, I'm getting a shape mismatch error:

Inner dimensions of output shape must match inner dimensions of updates shape Output: [4,3,2] updates: [4,2,2]

Answer:

It is similar to matrix multiplication in that the shapes have to line up: updates supplies the values, and indices says where each of them goes.

  1. The op does not reinterpret your data: it takes the inputs in the order given and depends on their number of elements and dimensions.
  2. "Outer" and "inner" in the error message refer to dimensions, much like the dimensions that must agree when multiplying matrices, e.g. (8, 2) x (2, 8). Here the rule is updates.shape == indices.shape[:-1] + tensor.shape[indices.shape[-1]:].
  3. Specify the input types: indices must be an integer type (int32 or int64), and updates must have the same dtype as tensor (float32 in the question).
  4. For more than 2 dimensions you can either convert the data back to 2 dimensions or vary the index input per batch, e.g. write a function and loop through the indices.
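The shape rule from point 2 can be checked without TensorFlow. Below is a small sketch; expected_updates_shape is a hypothetical helper written for illustration, not part of the TensorFlow API. It computes the updates shape that tf.tensor_scatter_nd_update expects for given tensor and indices shapes.

```python
def expected_updates_shape(tensor_shape, indices_shape):
    """Hypothetical helper: the shape `updates` must have.

    The last dimension of `indices` (the index depth) says how many leading
    dimensions of the tensor each index addresses; every index then needs a
    full slice of the remaining (inner) dimensions of the tensor.
    """
    index_depth = indices_shape[-1]
    return tuple(indices_shape[:-1]) + tuple(tensor_shape[index_depth:])


print(expected_updates_shape((3, 2), (2, 1)))        # (2, 2): the working rank-2 case
print(expected_updates_shape((4, 3, 2), (4, 2, 1)))  # (4, 2, 3, 2): why (4, 2, 2) was rejected
print(expected_updates_shape((4, 3, 2), (4, 2, 2)))  # (4, 2, 2): indices of the form [batch, row]
```

With index depth 1, each index in the batched attempt selects a whole (3, 2) batch slice, so updates would need shape (4, 2, 3, 2); making each index [batch, row] (depth 2) brings the required shape down to (4, 2, 2).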

Sample 1:: Simple solution

import tensorflow as tf

tensor = tf.zeros((8, 2), dtype=tf.int32)
indices = [[[0], [2]],
           [[1], [0]],
           [[0], [1]],
           [[0], [2]]] 
indices = tf.constant(indices, shape=(8, 1), dtype=tf.int32)

updates = [[[5, 5], [10, 10]], 
           [[1, 1], [7, 7]],
           [[3, 3], [2, 2]],
           [[5, 5], [1, 1]]]  
updates = tf.constant(updates, shape=(8, 2), dtype=tf.int32)    
output = tf.tensor_scatter_nd_update(tensor, indices, updates)
print(output)   

Output

tf.Tensor(
[[5 5]
 [2 2]
 [1 1]
 [0 0]
 [0 0]
 [0 0]
 [0 0]
 [0 0]], shape=(8, 2), dtype=int32)
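In Sample 1 every batch uses row indices 0 to 2, so after flattening, updates from different batches land on the same rows 0 to 2 and overwrite each other, while rows 3 to 7 stay zero. A sketch of the same "convert to 2 dimensions" idea that keeps batches apart, assuming the question's (4, 3, 2) tensor and data: each batch's row index is shifted by batch * rows before scattering into the flattened (12, 2) tensor, and the result is reshaped back.

```python
import tensorflow as tf

batch, rows, cols = 4, 3, 2
tensor = tf.zeros((batch, rows, cols), dtype=tf.int32)
indices = tf.constant([[[0], [2]],
                       [[1], [0]],
                       [[0], [1]],
                       [[0], [2]]])              # (4, 2, 1), int32
updates = tf.constant([[[5, 5], [10, 10]],
                       [[1, 1], [7, 7]],
                       [[3, 3], [2, 2]],
                       [[5, 5], [1, 1]]])         # (4, 2, 2), int32

# Shift each batch's row indices into its own block of the flat tensor:
# batch b owns rows b * rows .. b * rows + rows - 1 (offsets 0, 3, 6, 9).
offsets = tf.range(batch)[:, None, None] * rows      # (4, 1, 1)
flat_indices = tf.reshape(indices + offsets, (-1, 1))  # (8, 1)
flat_updates = tf.reshape(updates, (-1, cols))         # (8, 2)

flat = tf.reshape(tensor, (batch * rows, cols))        # (12, 2)
flat = tf.tensor_scatter_nd_update(flat, flat_indices, flat_updates)
output = tf.reshape(flat, (batch, rows, cols))         # back to (4, 3, 2)
print(output)
```

Because every flat index is now unique, no update overwrites another, and each batch of the output receives exactly its own two rows.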

Sample 2:: Transforming the input is the easiest way

tensor = tf.zeros((4, 3, 2), dtype=tf.int32)
tensor = tf.reshape(tensor, (12, 2))  # flatten batch and row dimensions into one
indices = [[[0], [2], [1]],
           [[1], [0], [1]],
           [[0], [1], [1]],
           [[0], [2], [1]]] 
indices = tf.constant(indices, shape=(12, 1), dtype=tf.int32)

updates = [[[5, 5, 1], [10, 10, 1]], 
           [[1, 1, 1], [7, 7, 1]],
           [[3, 3, 1], [2, 2, 1]],
           [[5, 5, 1], [1, 1, 1]]]  
updates = tf.constant(updates, shape=(12, 2), dtype=tf.int32)   
output = tf.tensor_scatter_nd_update(tensor, indices, updates)
print(output)

Output

tf.Tensor(
[[5 5]
 [1 1]
 [1 1]
 [0 0]
 [0 0]
 [0 0]
 [0 0]
 [0 0]
 [0 0]
 [0 0]
 [0 0]
 [0 0]], shape=(12, 2), dtype=int32)
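Most rows of the Sample 2 result stay zero because the 12 flattened indices only ever point at rows 0, 1 and 2. A quick count in plain Python (no TensorFlow needed) makes this visible; the flattening below follows the same row-major order as reshaping to (12, 1). When several updates target the same row, I would not rely on which one ends up in the output.

```python
from collections import Counter

# Row indices from Sample 2, shape (4, 3, 1).
sample2_indices = [[[0], [2], [1]],
                   [[1], [0], [1]],
                   [[0], [1], [1]],
                   [[0], [2], [1]]]

# Flatten in row-major order, like reshaping to (12, 1).
flat = [row[0] for batch in sample2_indices for row in batch]
counts = Counter(flat)

print(flat)                         # [0, 2, 1, 1, 0, 1, 0, 1, 1, 0, 2, 1]
print(dict(sorted(counts.items())))  # {0: 4, 1: 6, 2: 2}
```

Only 3 distinct rows out of 12 are targeted, and each of them is written several times.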

Sample 3:: Repeating the update per batch with a counter function

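index = 5

# Decrements the global counter and reports whether it is still positive.
# This is a plain Python function: under @tf.function the global side
# effect would only run while the function is being traced.
def reverse_count(index_input):
    global index
    index = index - 1
    return index > 0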

tensor = tf.zeros((3, 2))
tensor = tf.cast( tensor, dtype=tf.int32 )

index = 5
indices = [[[0], [2], [1]],
           [[1], [0], [1]],
           [[0], [1], [1]],
           [[0], [2], [1]]] 
indices = tf.constant(indices, shape=(4, 3, 1), dtype=tf.int32)

updates = [[5, 5], [10, 10], [1, 1]]
updates = tf.constant(updates, shape=(3, 2), dtype=tf.int32)    

print( index )
indice = tf.where([reverse_count( index )], indices[index - 1,:,:1], tf.constant([[-1, -1, -1]], shape=(3, 1)).numpy() )
print( 'tensor' )
print( tensor )
print( 'indice' )
print( indice )
print( 'updates' )
print( updates )
output = tf.tensor_scatter_nd_update(tensor, indice, updates)
print( 'output' )
print( output )

Output

5
tensor
tf.Tensor(
[[0 0]
 [0 0]
 [0 0]], shape=(3, 2), dtype=int32)
indice
tf.Tensor(
[[0]
 [2]
 [1]], shape=(3, 1), dtype=int32)
updates
tf.Tensor(
[[ 5  5]
 [10 10]
 [ 1  1]], shape=(3, 2), dtype=int32)
output
tf.Tensor(
[[ 5  5]
 [ 1  1]
 [10 10]], shape=(3, 2), dtype=int32)
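The per-batch idea behind Sample 3 can also be written as an ordinary Python loop that reuses the rank-2 call from the question on each batch slice and stacks the results. This is a minimal sketch assuming the question's (4, 3, 2) tensor and data; it issues one scatter per batch, which is simple but likely slower than a single vectorized call for large batches.

```python
import tensorflow as tf

tensor = tf.zeros((4, 3, 2))
indices = tf.constant([[[0], [2]],
                       [[1], [0]],
                       [[0], [1]],
                       [[0], [2]]])                       # (4, 2, 1)
updates = tf.constant([[[5, 5], [10, 10]],
                       [[1, 1], [7, 7]],
                       [[3, 3], [2, 2]],
                       [[5, 5], [1, 1]]], dtype=tf.float32)  # (4, 2, 2)

# Treat each batch as the rank-2 case that already works:
# tensor[b] is (3, 2), indices[b] is (2, 1), updates[b] is (2, 2).
slices = [tf.tensor_scatter_nd_update(tensor[b], indices[b], updates[b])
          for b in range(tensor.shape[0])]
output = tf.stack(slices)  # back to (4, 3, 2)
print(output)
```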

Answer:

If you want to update along the first dimension, for example rows [0], [2] and [3], then each update is a full (3, 2) slice, i.e. shape (1, 3, 2), and since there are three updates, updates has shape (3, 3, 2):

tensor = tf.zeros((4, 3, 2))
indices = [[0], [2], [3]] # shape (3,  1)

updates = [[[1, 1], [2, 2], [3, 3]], # <- this is [0] and will go to output[0,:,:]
          [[10, 10], [10, 10], [10, 10]], # <- this is [2]
          [[101,102], [103, 104], [105, 106]]] # shape (3,  3, 2)

output = tf.tensor_scatter_nd_update(tensor, indices, updates)

output:

tf.Tensor(
[[[  1.   1.]
  [  2.   2.]
  [  3.   3.]]

 [[  0.   0.]
  [  0.   0.]
  [  0.   0.]]

 [[ 10.  10.]
  [ 10.  10.]
  [ 10.  10.]]

 [[101. 102.]
  [103. 104.]
  [105. 106.]]], shape=(4, 3, 2), dtype=float32)

If you want to update along the first two dimensions, for example positions [0, 1] and [2, 2], then each update has shape (1, 1, 2), and since there are two updates, updates has shape (2, 1, 2):

tensor = tf.zeros((4, 3, 2))
indices = [[[0,1]], [[2,2]]] # shape (2, 1,  2)

updates = [[[1, 1]], # <- this will go to output[0,1,:]
          [[10, 10]]] # shape (2, 1, 2)

output = tf.tensor_scatter_nd_update(tensor, indices, updates)

output:

tf.Tensor(
[[[ 0.  0.]
  [ 1.  1.]
  [ 0.  0.]]

 [[ 0.  0.]
  [ 0.  0.]
  [ 0.  0.]]

 [[ 0.  0.]
  [ 0.  0.]
  [10. 10.]]

 [[ 0.  0.]
  [ 0.  0.]
  [ 0.  0.]]], shape=(4, 3, 2), dtype=float32)
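Applying the "first two dimensions" idea to the original question: if the batch number is prepended to every row index, each index becomes [batch, row] (index depth 2), and updates then only has to supply the last dimension, so the question's (4, 2, 2) updates fit as they are. A sketch assuming the question's data; the names batch_ids and full_indices are illustrative.

```python
import tensorflow as tf

tensor = tf.zeros((4, 3, 2))
indices = tf.constant([[[0], [2]],
                       [[1], [0]],
                       [[0], [1]],
                       [[0], [2]]])                       # (4, 2, 1), int32
updates = tf.constant([[[5, 5], [10, 10]],
                       [[1, 1], [7, 7]],
                       [[3, 3], [2, 2]],
                       [[5, 5], [1, 1]]], dtype=tf.float32)  # (4, 2, 2)

batch, n = indices.shape[0], indices.shape[1]
# Batch id for every index, broadcast to (4, 2, 1): [[[0], [0]], [[1], [1]], ...]
batch_ids = tf.broadcast_to(tf.range(batch)[:, None, None], (batch, n, 1))
# Each index is now [batch, row]; required updates shape is (4, 2) + (2,).
full_indices = tf.concat([batch_ids, indices], axis=-1)  # (4, 2, 2)

output = tf.tensor_scatter_nd_update(tensor, full_indices, updates)
print(output)
```

This keeps everything in one vectorized call, without reshaping the tensor or looping over batches.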