InvalidArgumentError: Incompatible shapes: [32,128] vs. [128,128] [Op:BroadcastTo]


I want to broadcast my tensor of size (32, 128) to (128, 128), but it generates this error: InvalidArgumentError: Incompatible shapes: [32,128] vs. [128,128] [Op:BroadcastTo].

I want to know whether it's possible to broadcast it or not, and what could be causing this error. I use the following code:

loss = tf.broadcast_to(kl_loss, [128, 128])

CodePudding user response:

We can use tf.broadcast_to, but we first need to broadcast the 2D tensor to a 3D tensor and then reshape it back to 2D, like below.

import tensorflow as tf

kl_loss = tf.random.uniform((32, 128))
print(kl_loss.shape)
# (32, 128)

# broadcast to (4, 32, 128) -- the leading 4 is 128 // 32 -- then flatten back to 2D
result = tf.reshape(tf.broadcast_to(kl_loss, [4, 32, 128]), [128, 128])
print(result.shape)
# (128, 128)
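
As a side note (not part of the original answer, but it addresses the "why" in the question): tf.broadcast_to follows NumPy-style broadcasting rules, so an existing dimension can only be stretched when it is 1, although new leading dimensions can be added freely. Since 32 is neither 1 nor equal to 128, the direct call fails, which is why the extra leading dimension is needed first. A minimal sketch:

import tensorflow as tf

# A size-1 dimension broadcasts without trouble:
print(tf.broadcast_to(tf.random.uniform((1, 128)), [128, 128]).shape)
# (128, 128)

# 32 is neither 1 nor 128, so this line raises the InvalidArgumentError from the question:
# tf.broadcast_to(tf.random.uniform((32, 128)), [128, 128])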

What will the output look like? (Example for broadcasting from (2, 3) to (4, 3))

tf.random.set_seed(123)

x = tf.random.uniform((2, 3))
print(x)

y = tf.broadcast_to(x, [2, 2, 3])
print(y)

z = tf.reshape(y, [4, 3])
print(z)

# x =>
tf.Tensor(
[[0.12615311 0.5727513  0.2993133 ]
 [0.5461836  0.7205157  0.7889533 ]], shape=(2, 3), dtype=float32)

# y =>
tf.Tensor(
[[[0.12615311 0.5727513  0.2993133 ]
  [0.5461836  0.7205157  0.7889533 ]]

 [[0.12615311 0.5727513  0.2993133 ]
  [0.5461836  0.7205157  0.7889533 ]]], shape=(2, 2, 3), dtype=float32)

# z =>
tf.Tensor(
[[0.12615311 0.5727513  0.2993133 ]
 [0.5461836  0.7205157  0.7889533 ]
 [0.12615311 0.5727513  0.2993133 ]
 [0.5461836  0.7205157  0.7889533 ]], shape=(4, 3), dtype=float32)
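
If you want this trick without hard-coding the factor, a small helper can compute it from the shapes. This is just a sketch of the same broadcast-then-reshape idea; the function name broadcast_rows and the assertion are my additions, not part of the original answer:

import tensorflow as tf

def broadcast_rows(x, target_rows):
    # Repeat an (n, d) tensor along axis 0 to get shape (target_rows, d).
    # Assumes target_rows is an exact multiple of n.
    n, d = x.shape
    factor, remainder = divmod(target_rows, n)
    assert remainder == 0, "target_rows must be a multiple of the row count"
    return tf.reshape(tf.broadcast_to(x, [factor, n, d]), [target_rows, d])

kl_loss = tf.random.uniform((32, 128))
print(broadcast_rows(kl_loss, 128).shape)
# (128, 128)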

CodePudding user response:

You need to somehow upsample your tensor. Try tf.tile, tf.repeat, or tf.concat:

import tensorflow as tf

kl_loss = tf.random.normal((32, 128))
print(tf.repeat(kl_loss, repeats = int(128 / tf.shape(kl_loss)[0]), axis=0).shape)
print(tf.tile(kl_loss, multiples=[int(128 / tf.shape(kl_loss)[0]), 1]).shape)
print(tf.concat([kl_loss, kl_loss, kl_loss, kl_loss], axis=0).shape)
(128, 128)
(128, 128)
(128, 128)
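
One caveat worth noting (my addition): these options do not all produce the same row order. tf.repeat with axis=0 duplicates each row in place, while tf.tile and tf.concat repeat the whole block, which matches what the broadcast-and-reshape trick above produces. A small sketch to illustrate:

import tensorflow as tf

x = tf.constant([[1, 2], [3, 4]])

# tf.repeat duplicates each row consecutively:
print(tf.repeat(x, repeats=2, axis=0).numpy())
# [[1 2]
#  [1 2]
#  [3 4]
#  [3 4]]

# tf.tile (like tf.concat) repeats the whole block:
print(tf.tile(x, multiples=[2, 1]).numpy())
# [[1 2]
#  [3 4]
#  [1 2]
#  [3 4]]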