Custom Operations on Multi-dimensional Tensors

Time:08-03

I am trying to compute the tensor R (see image), and the only way I could explain what I am trying to compute is by working it out on paper:

[Image: the computation of R worked out on paper]

import torch

o = torch.tensor([[[1, 3, 2], [7, 9, 8], [13, 15, 14], [19, 21, 20], [25, 27, 26]], [[31, 33, 32], [37, 39, 38], [43, 45, 44], [49, 51, 50], [55, 57, 56]]])
p = torch.tensor([[[19, 21, 20], [7, 9, 8], [13, 15, 14], [1, 3, 2], [25, 27, 26]], [[55, 57, 56], [31, 33, 32], [37, 39, 38], [43, 45, 44], [49, 51, 50]]])

# this is O' in the image
o_prime = torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5], [0.6, 0.7, 0.8, 0.9, 0.11]])

# this is P' in the image
p_prime = torch.tensor([[1.1, 1.2, 1.3, 1.4, 1.5], [1.6, 1.7, 1.8, 1.9, 1.11]])

# this is R (what I need); these values are sum(o) + o' only, p' still has to be added (see the correction below)
r = torch.tensor([[[0, 0, 0, 6.1, 0], [0, 24.2, 0, 0, 0], [0, 0, 42.3, 0, 0], [60.4, 0, 0, 0, 0], [0, 0, 0, 0, 78.5]], [[0, 96.6, 0, 0, 0], [0, 0, 114.7, 0, 0], [0, 0, 0, 132.8, 0], [0, 0, 0, 0, 150.9], [168.11, 0, 0, 0, 0]]])

How do I get R without looping over tensors?

Correction: in the image, I forgot to add the value of p' to sum(o) + o'.
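For reference, here is the computation as an explicit loop, i.e. the version the question wants to avoid. It is only a sketch that pins down the definition of R, and it assumes, as both answers do, that p' is indexed by the same row i as o'; the function name `r_loop` is hypothetical.

```python
import torch

def r_loop(o, p, o_prime, p_prime):
    """Slow reference: R[b, i, j] = sum(o[b, i]) + o'[b, i] + p'[b, i]
    when row o[b, i] equals row p[b, j], otherwise 0.
    Assumption: p' is indexed by the row i of o, as in the answers."""
    B, N, _ = o.shape
    M = p.shape[1]
    r = torch.zeros(B, N, M)
    for b in range(B):
        for i in range(N):
            for j in range(M):
                # torch.equal is True only if all three elements match
                if torch.equal(o[b, i], p[b, j]):
                    r[b, i, j] = o[b, i].sum() + o_prime[b, i] + p_prime[b, i]
    return r

o = torch.tensor([[[1, 3, 2], [7, 9, 8], [13, 15, 14], [19, 21, 20], [25, 27, 26]],
                  [[31, 33, 32], [37, 39, 38], [43, 45, 44], [49, 51, 50], [55, 57, 56]]])
p = torch.tensor([[[19, 21, 20], [7, 9, 8], [13, 15, 14], [1, 3, 2], [25, 27, 26]],
                  [[55, 57, 56], [31, 33, 32], [37, 39, 38], [43, 45, 44], [49, 51, 50]]])
o_prime = torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5], [0.6, 0.7, 0.8, 0.9, 0.11]])
p_prime = torch.tensor([[1.1, 1.2, 1.3, 1.4, 1.5], [1.6, 1.7, 1.8, 1.9, 1.11]])

R = r_loop(o, p, o_prime, p_prime)
```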

Answer 1:

You can construct a helper tensor containing the resulting values sum(o) + o' + p':

>>> v = o.sum(2, True) + o_prime[..., None] + p_prime[..., None]
>>> v
tensor([[[  7.2000],
         [ 25.4000],
         [ 43.6000],
         [ 61.8000],
         [ 80.0000]],

        [[ 98.2000],
         [116.4000],
         [134.6000],
         [152.8000],
         [169.2200]]])
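The `[..., None]` indexing and the `True` (keepdim) argument both leave a trailing axis of size 1, so v has shape (2, 5, 1) and can later broadcast across the five columns of the mask. A quick check of the shapes and of the first entry, 1 + 3 + 2 + 0.1 + 1.1 = 7.2 (the variable name `row_sums` is illustrative):

```python
import torch

o = torch.tensor([[[1, 3, 2], [7, 9, 8], [13, 15, 14], [19, 21, 20], [25, 27, 26]],
                  [[31, 33, 32], [37, 39, 38], [43, 45, 44], [49, 51, 50], [55, 57, 56]]])
o_prime = torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5], [0.6, 0.7, 0.8, 0.9, 0.11]])
p_prime = torch.tensor([[1.1, 1.2, 1.3, 1.4, 1.5], [1.6, 1.7, 1.8, 1.9, 1.11]])

row_sums = o.sum(2, keepdim=True)                    # sum of each 3-element row, shape (2, 5, 1)
v = row_sums + o_prime[..., None] + p_prime[..., None]

print(row_sums.shape, v.shape)   # torch.Size([2, 5, 1]) torch.Size([2, 5, 1])
print(row_sums[0, 0, 0].item())  # 6
```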

Then you can assemble a mask for the final tensor via broadcasting, comparing every row o[b, i] against every row p[b, j]:

>>> eq = o[:, :, None] == p[:, None]

Reduce over the last dimension so that a position only counts as a match when all three elements agree:

>>> eq.all(dim=-1)
tensor([[[False, False, False,  True, False],
         [False,  True, False, False, False],
         [False, False,  True, False, False],
         [ True, False, False, False, False],
         [False, False, False, False,  True]],

        [[False,  True, False, False, False],
         [False, False,  True, False, False],
         [False, False, False,  True, False],
         [False, False, False, False,  True],
         [ True, False, False, False, False]]])

(The axis order matters: o[:, None] == p[:, :, None] would give the transposed mask, which happens to be identical for the first batch, where the permutation is symmetric, but not for the second.)

Finally, multiply the mask by v; broadcasting stretches v (shape (2, 5, 1)) across the five columns:

>>> R = eq.all(dim=-1) * v
>>> R
tensor([[[  0.0000,   0.0000,   0.0000,   7.2000,   0.0000],
         [  0.0000,  25.4000,   0.0000,   0.0000,   0.0000],
         [  0.0000,   0.0000,  43.6000,   0.0000,   0.0000],
         [ 61.8000,   0.0000,   0.0000,   0.0000,   0.0000],
         [  0.0000,   0.0000,   0.0000,   0.0000,  80.0000]],

        [[  0.0000,  98.2000,   0.0000,   0.0000,   0.0000],
         [  0.0000,   0.0000, 116.4000,   0.0000,   0.0000],
         [  0.0000,   0.0000,   0.0000, 134.6000,   0.0000],
         [  0.0000,   0.0000,   0.0000,   0.0000, 152.8000],
         [169.2200,   0.0000,   0.0000,   0.0000,   0.0000]]])
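The broadcasting approach can be wrapped into a small function. This is a sketch, not a library API: `r_broadcast` is a hypothetical name, the mask is built as o[:, :, None] == p[:, None] so that rows index o and columns index p, and p' is assumed to be indexed by the row of o, as in the answer.

```python
import torch

def r_broadcast(o, p, o_prime, p_prime):
    # mask[b, i, j] is True when row o[b, i] equals row p[b, j] in every element
    mask = (o[:, :, None] == p[:, None]).all(dim=-1)                        # (B, N, M)
    # one value per row of o; assumption: p' indexed by the row of o
    v = o.sum(-1, keepdim=True) + o_prime[..., None] + p_prime[..., None]  # (B, N, 1)
    # bool * float -> float; v broadcasts across the M columns
    return mask * v                                                         # (B, N, M)

o = torch.tensor([[[1, 3, 2], [7, 9, 8], [13, 15, 14], [19, 21, 20], [25, 27, 26]],
                  [[31, 33, 32], [37, 39, 38], [43, 45, 44], [49, 51, 50], [55, 57, 56]]])
p = torch.tensor([[[19, 21, 20], [7, 9, 8], [13, 15, 14], [1, 3, 2], [25, 27, 26]],
                  [[55, 57, 56], [31, 33, 32], [37, 39, 38], [43, 45, 44], [49, 51, 50]]])
o_prime = torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5], [0.6, 0.7, 0.8, 0.9, 0.11]])
p_prime = torch.tensor([[1.1, 1.2, 1.3, 1.4, 1.5], [1.6, 1.7, 1.8, 1.9, 1.11]])

R = r_broadcast(o, p, o_prime, p_prime)
```

One consequence of this design: because v holds one value per row of o, a row of o that matched several rows of p would put the same value in every matching column.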

Answer 2:

You can compute a mask of shape (2, 5, 5) that is True where row o[b, i] equals row p[b, j]:

mask = (o.repeat(1,1,p.shape[1]).reshape(o.shape[0],o.shape[1],p.shape[1],o.shape[2])==p.reshape(p.shape[0],1,p.shape[1],p.shape[2])).all(-1)

mask is:

tensor([[[False, False, False,  True, False],
         [False,  True, False, False, False],
         [False, False,  True, False, False],
         [ True, False, False, False, False],
         [False, False, False, False,  True]],

        [[False,  True, False, False, False],
         [False, False,  True, False, False],
         [False, False, False,  True, False],
         [False, False, False, False,  True],
         [ True, False, False, False, False]]])
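The repeat/reshape construction is equivalent to inserting singleton axes and letting broadcasting expand them, which avoids building a tiled copy of o. A quick check that both produce the same mask (`mask_repeat` and `mask_bcast` are illustrative names):

```python
import torch

o = torch.tensor([[[1, 3, 2], [7, 9, 8], [13, 15, 14], [19, 21, 20], [25, 27, 26]],
                  [[31, 33, 32], [37, 39, 38], [43, 45, 44], [49, 51, 50], [55, 57, 56]]])
p = torch.tensor([[[19, 21, 20], [7, 9, 8], [13, 15, 14], [1, 3, 2], [25, 27, 26]],
                  [[55, 57, 56], [31, 33, 32], [37, 39, 38], [43, 45, 44], [49, 51, 50]]])

# the answer's version: tile each row of o M times, reshape to (B, N, M, 3), compare with p as (B, 1, M, 3)
mask_repeat = (o.repeat(1, 1, p.shape[1]).reshape(o.shape[0], o.shape[1], p.shape[1], o.shape[2])
               == p.reshape(p.shape[0], 1, p.shape[1], p.shape[2])).all(-1)

# same comparison via broadcasting: (B, N, 1, 3) vs (B, 1, M, 3) -> (B, N, M, 3)
mask_bcast = (o[:, :, None] == p[:, None]).all(-1)

print(torch.equal(mask_repeat, mask_bcast))  # True
```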

Then you can compute the indices of the True entries:

ids = torch.argwhere(mask)

Each row of ids is a (batch, i, j) triple telling you which elements of o and p to combine in the equation:

tensor([[0, 0, 3],
        [0, 1, 1],
        [0, 2, 2],
        [0, 3, 0],
        [0, 4, 4],
        [1, 0, 1],
        [1, 1, 2],
        [1, 2, 3],
        [1, 3, 4],
        [1, 4, 0]])

Finally, you can create a zero tensor of the same shape as the mask and fill the matching positions with the equation:

R = torch.zeros(mask.shape)
R[mask] = o[ids[:,0], ids[:,1]].sum(1) + o_prime[ids[:,0], ids[:,1]] + p_prime[ids[:,0], ids[:,1]]

The resulting R is:

tensor([[[  0.0000,   0.0000,   0.0000,   7.2000,   0.0000],
         [  0.0000,  25.4000,   0.0000,   0.0000,   0.0000],
         [  0.0000,   0.0000,  43.6000,   0.0000,   0.0000],
         [ 61.8000,   0.0000,   0.0000,   0.0000,   0.0000],
         [  0.0000,   0.0000,   0.0000,   0.0000,  80.0000]],

        [[  0.0000,  98.2000,   0.0000,   0.0000,   0.0000],
         [  0.0000,   0.0000, 116.4000,   0.0000,   0.0000],
         [  0.0000,   0.0000,   0.0000, 134.6000,   0.0000],
         [  0.0000,   0.0000,   0.0000,   0.0000, 152.8000],
         [169.2200,   0.0000,   0.0000,   0.0000,   0.0000]]])
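The assignment R[mask] = ... lines up because torch.argwhere(mask) and boolean-mask indexing both visit the True positions in the same row-major order, so the k-th computed value lands on the k-th True entry. The sketch below checks that against scattering with the three columns of ids directly; the names are illustrative, and the mask is built by broadcasting, which gives the same result as the repeat/reshape version.

```python
import torch

o = torch.tensor([[[1, 3, 2], [7, 9, 8], [13, 15, 14], [19, 21, 20], [25, 27, 26]],
                  [[31, 33, 32], [37, 39, 38], [43, 45, 44], [49, 51, 50], [55, 57, 56]]])
p = torch.tensor([[[19, 21, 20], [7, 9, 8], [13, 15, 14], [1, 3, 2], [25, 27, 26]],
                  [[55, 57, 56], [31, 33, 32], [37, 39, 38], [43, 45, 44], [49, 51, 50]]])
o_prime = torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5], [0.6, 0.7, 0.8, 0.9, 0.11]])
p_prime = torch.tensor([[1.1, 1.2, 1.3, 1.4, 1.5], [1.6, 1.7, 1.8, 1.9, 1.11]])

mask = (o[:, :, None] == p[:, None]).all(-1)   # (2, 5, 5)
ids = torch.argwhere(mask)                     # (K, 3), rows are (b, i, j) in row-major order
b, i, j = ids.unbind(1)
# p' indexed by i, as in the answer
vals = o[b, i].sum(1) + o_prime[b, i] + p_prime[b, i]

R_mask = torch.zeros(mask.shape)
R_mask[mask] = vals                            # boolean indexing: True positions in row-major order

R_ids = torch.zeros(mask.shape)
R_ids[b, i, j] = vals                          # explicit scatter to each (b, i, j)

print(torch.equal(R_mask, R_ids))              # True
```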