Pytorch/Numpy: Subtract each of N elements from a single matrix, resulting in N matrices?

Time:06-30

Question in the title: is there an operation, or a way to use broadcasting, that does this without looping? Here's a simple example using a list comprehension:

import torch

image = torch.tensor([[6, 9], [8.7, 5.5]])
c = torch.tensor([5.7675, 8.8325])

# with a list comprehension: one (2, 2) matrix per element of c
desired_result = torch.stack([image - c_i for c_i in c])

# output:
tensor([[[ 0.2325,  3.2325],
         [ 2.9325, -0.2675]],

        [[-2.8325,  0.1675],
         [-0.1325, -3.3325]]])

I've tried reshaping the "scalar array" c in every way I could think of, but couldn't get the desired result.

Answer:

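As far as I know, torch has no general outer subtraction (torch.outer only computes the outer product of two 1-D vectors), but NumPy's ufunc outer method does exactly this: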

import numpy as np

# subtract.outer(c, image)[i] == c[i] - image, so negate to get image - c[i]
-np.subtract.outer(c.numpy(), image.numpy())

Output:

array([[[ 0.23250008,  3.2325    ],
        [ 2.9325    , -0.26749992]],

       [[-2.8325005 ,  0.16749954],
        [-0.13250065, -3.3325005 ]]], dtype=float32)
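For reference, NumPy can also do this without outer at all: giving c two trailing singleton axes (shape (N, 1, 1)) lets broadcasting pair every c[i] with the whole (2, 2) image. A minimal sketch, reusing the question's values as float32 arrays; the names via_outer, via_broadcast and via_loop are only for illustration:

```python
import numpy as np

# Same data as in the question, as NumPy float32 arrays
image = np.array([[6, 9], [8.7, 5.5]], dtype=np.float32)
c = np.array([5.7675, 8.8325], dtype=np.float32)

# ufunc.outer(a, b) has shape a.shape + b.shape with out[i, j, k] = a[i] - b[j, k],
# so the result is negated to get image - c[i]
via_outer = -np.subtract.outer(c, image)

# Plain broadcasting: c[:, None, None] has shape (N, 1, 1), which broadcasts
# against image of shape (2, 2) to a result of shape (N, 2, 2)
via_broadcast = image - c[:, np.newaxis, np.newaxis]

# The question's loop, for comparison
via_loop = np.stack([image - c_i for c_i in c])

print(via_broadcast.shape)  # (2, 2, 2)
```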

In torch, you can add a trailing axis to c, flatten image, let broadcasting do the subtraction, and reshape the result:

# (4,) - (N, 1) broadcasts to (N, 4); then reshape to (N, 2, 2)
(image.ravel() - c[:, None]).reshape(*c.shape, *image.shape)

Output:

tensor([[[ 0.2325,  3.2325],
         [ 2.9325, -0.2675]],

        [[-2.8325,  0.1675],
         [-0.1325, -3.3325]]])
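The same broadcasting idea also works in torch without flattening: for a 1-D c, image - c[:, None, None] already has shape (N, 2, 2). A sketch that generalizes this to any shapes of c and image, under the assumption that you want the output shape c.shape + image.shape as in the reshape above; subtract_each is a hypothetical helper name, not a torch function:

```python
import torch

def subtract_each(image: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: return shape c.shape + image.shape with out[i] = image - c[i]."""
    # Append one singleton axis per image dimension, e.g. (N,) -> (N, 1, 1) for a 2-D image
    c_expanded = c.reshape(*c.shape, *([1] * image.dim()))
    # Broadcasting aligns trailing axes: (N, 1, 1) against (2, 2) gives (N, 2, 2)
    return image - c_expanded

image = torch.tensor([[6, 9], [8.7, 5.5]])
c = torch.tensor([5.7675, 8.8325])

result = subtract_each(image, c)
# One-liner equivalent for a 1-D c
one_liner = image - c[:, None, None]
# The question's loop, for comparison
loop = torch.stack([image - c_i for c_i in c])

print(torch.allclose(result, loop))  # True
```

Compared with the flatten-and-reshape version, this needs no negation and no reshape of the result; only c is reshaped, and broadcasting handles the expansion.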