Nearest Neighbour difference in Numpy/PyTorch

Time:10-18

I need to write a custom loss function in PyTorch, but because PyTorch is so similar to NumPy, a NumPy-based solution will also work.

I have two tensors (NumPy arrays) p and q of shape (b, ...). For each batch element of p, I want to compute the minimum difference with respect to any batch element of q. Sample code is given below:

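import torch

b = p.shape[0]
loss = 0
for outer in range(b):
    tmp_min = float("inf")  # start at +infinity so any real difference is smaller
    for inner in range(b):
        # Sum over the non-batch axes so the comparison below is between scalars
        tmp_loss = torch.abs(p[outer, ...] - q[inner, ...]).sum()  # np.abs(p[outer, ...] - q[inner, ...]).sum()
        if tmp_loss < tmp_min:
            tmp_min = tmp_loss
    loss = loss + tmp_min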

Since I will have to compute this loss many times, is there a way to do it without for loops and if statements?

Finally, I want to compute this loss in both directions. Is there an alternative to repeating the above code with p and q swapped?
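Written out, and assuming the "difference" between two batch elements is the sum of absolute differences over all non-batch axes (the interpretation the answer uses), the two directional losses are:

$$
L_{p \to q} = \sum_{i=1}^{b} \min_{j} \lVert p_i - q_j \rVert_1,
\qquad
L_{q \to p} = \sum_{j=1}^{b} \min_{i} \lVert p_i - q_j \rVert_1
$$

Both are reductions of the same b x b matrix $D_{ij} = \lVert p_i - q_j \rVert_1$: the first takes the minimum of each row, the second the minimum of each column.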

Answer:

You can add a singleton dimension to p so that broadcasting in p - q produces every pairwise difference: (p[:, None, ...] - q[:, ...]).shape == (B, B, ...), where entry [i, j] is p[i] - q[j].
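A quick way to see what the broadcast produces, sketched in NumPy with small, made-up shapes (B = 4 and trailing axes (2, 3) are illustrative only): the result has shape (B, B, ...) and entry [i, j] equals p[i] - q[j].

```python
import numpy as np

B = 4                      # batch size (illustrative)
rng = np.random.default_rng(0)
p = rng.random((B, 2, 3))  # shape (B, ...) with ... = (2, 3)
q = rng.random((B, 2, 3))

# p[:, None, ...] has shape (B, 1, 2, 3); q has shape (B, 2, 3), which
# broadcasts as (1, B, 2, 3). The result holds every pairwise difference.
diff = p[:, None, ...] - q
print(diff.shape)  # (4, 4, 2, 3)

# Entry [i, j] is the difference between p[i] and q[j].
assert np.array_equal(diff[1, 3], p[1] - q[3])
```

The cost is a (B, B, ...) intermediate, so memory grows with the square of the batch size.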

In PyTorch, it works as follows:

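import torch

def loss_vec(p, q):
    B = p.shape[0]
    assert q.shape[0] == B
    # Flatten the non-batch axes so each batch element is a vector of length D
    p = p.reshape(B, -1)
    q = q.reshape(B, -1)
    # (B, 1, D) - (B, D) -> (B, B, D); summing |.| over D gives pairwise L1 distances,
    # then take the nearest q for each p (min over the last axis) and sum
    return (p[:, None, :] - q).abs().sum(axis=-1).min(axis=-1).values.sum()

def loss_op(p, q):
    """OP solution as a one-liner"""
    return torch.tensor([min([torch.abs(x - y).sum() for y in q]) for x in p]).sum()


B, K, M, N = 11, 3, 5, 7
p = torch.rand(B, K, M, N)
q = torch.rand(B, K, M, N)

value_op_pq = loss_op(p, q)
value_vec_pq = loss_vec(p, q)

# The two versions may add the floats in a different order, so compare with a tolerance
assert torch.allclose(value_op_pq, value_vec_pq)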
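Since the question says a NumPy solution is also acceptable, here is a sketch of the same vectorized loss in NumPy. The main difference is that ndarray.min returns the values directly, so there is no .values attribute to unpack (in PyTorch, min with a dimension returns a (values, indices) pair). The loop reference loss_loop is a hypothetical helper written only for comparison, and the check uses np.isclose because the two versions may sum in a different order.

```python
import numpy as np

def loss_vec_np(p, q):
    """Sum over p of the L1 distance to the nearest element of q."""
    B = p.shape[0]
    assert q.shape[0] == B
    p = p.reshape(B, -1)  # flatten the non-batch axes
    q = q.reshape(B, -1)
    # dist[i, j] = sum(|p[i] - q[j]|), shape (B, B)
    dist = np.abs(p[:, None, :] - q).sum(axis=-1)
    return dist.min(axis=-1).sum()

def loss_loop(p, q):
    """Hypothetical loop reference, mirroring the question's code."""
    total = 0.0
    for x in p:
        total += min(np.abs(x - y).sum() for y in q)
    return total

rng = np.random.default_rng(0)
p = rng.random((11, 3, 5, 7))
q = rng.random((11, 3, 5, 7))
assert np.isclose(loss_vec_np(p, q), loss_loop(p, q))
```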

To compute the loss in both directions, build the pairwise distance matrix once and take the min along each axis (axis=-1 for p to q, axis=0 for q to p):

def loss_vec_bi(p, q):
    """Returns the loss in both directions"""
    B = p.shape[0]
    assert q.shape[0] == B
    p = p.reshape(B, -1)
    q = q.reshape(B, -1)
    # losses[i, j] = L1 distance between p[i] and q[j], shape (B, B)
    losses = (p[:, None, :] - q).abs().sum(axis=-1)
    # Row minima: nearest q for each p; column minima: nearest p for each q
    return losses.min(axis=-1).values.sum(), losses.min(axis=0).values.sum()


value_op_qp = loss_op(q, p)
value_vec_pq, value_vec_qp = loss_vec_bi(p, q)

assert torch.allclose(value_op_pq, value_vec_pq)
assert torch.allclose(value_op_qp, value_vec_qp)
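The same trick in NumPy, as a sketch: build the (B, B) distance matrix once, then reduce it along each axis. Row i holds the distances from p[i] to every element of q, so min(axis=-1) gives the p-to-q loss; column j holds the distances from every element of p to q[j], so min(axis=0) gives the q-to-p loss. The function name loss_vec_bi_np is hypothetical.

```python
import numpy as np

def loss_vec_bi_np(p, q):
    """Return (p->q loss, q->p loss) from one pairwise distance matrix."""
    B = p.shape[0]
    assert q.shape[0] == B
    p = p.reshape(B, -1)
    q = q.reshape(B, -1)
    # dist[i, j] = L1 distance between p[i] and q[j]
    dist = np.abs(p[:, None, :] - q).sum(axis=-1)
    # Rows: nearest q for each p. Columns: nearest p for each q.
    return dist.min(axis=-1).sum(), dist.min(axis=0).sum()

rng = np.random.default_rng(1)
p = rng.random((11, 3, 5, 7))
q = rng.random((11, 3, 5, 7))

pq, qp = loss_vec_bi_np(p, q)
# Swapping the arguments should swap the two directions.
qp_swapped, pq_swapped = loss_vec_bi_np(q, p)
assert np.isclose(pq, pq_swapped) and np.isclose(qp, qp_swapped)
```

Because both directions come from one distance matrix, the expensive broadcast is done once rather than twice.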