PyTorch: How to test whether each row of the first 2D tensor also exists in the second tensor?

Time:08-31

Given two tensors t1 and t2:

t1=torch.tensor([[1,2],[3,4],[5,6]])
t2=torch.tensor([[1,2],[5,6]])

If a row of t1 also exists in t2, the result for that row should be True, otherwise False. The ideal result is [True, False, True]. I tried torch.isin(t1, t2), but it returns results element by element, not row by row. By the way, if they were NumPy arrays, this could be done with

np.in1d(t1.view('i,i').reshape(-1), t2.view('i,i').reshape(-1))
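The view trick works because 'i,i' reinterprets each row as a single structured element, so the membership test compares whole rows instead of individual numbers. A minimal sketch, assuming 2-column int32 arrays: 'i,i' is two 4-byte ints, so on platforms where NumPy's default integer is int64 the same view would produce two structured elements per row instead of one. It uses np.isin, which recent NumPy versions recommend over np.in1d.

```python
import numpy as np

# Assumption: 2-column int32 arrays, so 'i,i' (two 4-byte ints) spans exactly one row.
a = np.array([[1, 2], [3, 4], [5, 6]], dtype=np.int32)
b = np.array([[1, 2], [5, 6]], dtype=np.int32)

# Reinterpret each row as one structured element: shape (n, 2) -> (n, 1) -> (n,)
a_rows = a.view('i,i').reshape(-1)
b_rows = b.view('i,i').reshape(-1)

# Membership is now tested on whole rows rather than on individual numbers.
mask = np.isin(a_rows, b_rows)
print(mask.tolist())  # [True, False, True]
```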

How can I get a similar result with tensors?
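PyTorch has no structured dtypes, so the view trick does not carry over directly, but the same "one key per row" idea can be imitated. A minimal sketch, assuming all entries are non-negative integers and that the packed keys fit in int64: each row is read as the digits of a number in base (max + 1), and torch.isin is then applied to those keys. The helper name rowwise_in_keys is hypothetical, and this key-packing approach is a swapped-in technique, not something from the question.

```python
import torch

def rowwise_in_keys(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: True for each row of a that also appears in b.

    Assumes non-negative integer entries and that base**c fits in int64.
    """
    assert a.shape[1] == b.shape[1], "Tensors must have same number of columns"
    c = a.shape[1]
    # Every entry is < base, so each row maps to a unique base-`base` number.
    base = int(torch.cat([a, b]).max()) + 1
    # Positional weights [base**(c-1), ..., base, 1]
    weights = base ** torch.arange(c - 1, -1, -1, dtype=torch.int64)
    a_keys = (a.to(torch.int64) * weights).sum(dim=1)
    b_keys = (b.to(torch.int64) * weights).sum(dim=1)
    # One key per row, so the element-wise isin is now a row-wise test.
    return torch.isin(a_keys, b_keys)

t1 = torch.tensor([[1, 2], [3, 4], [5, 6]])
t2 = torch.tensor([[1, 2], [5, 6]])

# For contrast: isin on the raw tensors answers per element, with the shape of t1.
print(torch.isin(t1, t2).tolist())  # [[True, True], [False, False], [True, True]]
print(rowwise_in_keys(t1, t2).tolist())  # [True, False, True]
```

This avoids building a comparison tensor over all pairs of rows, but it silently gives wrong answers if entries are negative or non-integer, or if the keys overflow int64.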

Answer:

import torch

def rowwise_in(a, b):
  """
  a - tensor of size a0, c
  b - tensor of size b0, c
  returns - tensor of size a0 with 1 for each row of a that appears in b, 0 otherwise
  """

  # dimensions
  a0 = a.shape[0]
  b0 = b.shape[0]
  c  = a.shape[1]
  assert c == b.shape[1], "Tensors must have same number of columns"

  # broadcast both tensors to shape (a0, b0, c) so that every row of a
  # is paired with every row of b
  a_expand = a.unsqueeze(1).expand(a0, b0, c)
  b_expand = b.unsqueeze(0).expand(a0, b0, c)

  # element-wise equality, shape (a0, b0, c)
  equal = a_expand == b_expand

  # product along dim 2: 1 only if all c elements are equal, i.e. row i of a
  # equals row j of b; shape (a0, b0)
  row_equal = torch.prod(equal, dim=2)

  # max along dim 1: 1 if row i of a matches at least one row of b; shape (a0,)
  row_in_b = torch.max(row_equal, dim=1)[0]
  return row_in_b
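The same broadcasting idea can be written more compactly with boolean reductions, which also returns True/False directly, as the question asks. A sketch under the same assumptions as the function above; the name rowwise_in_bool is hypothetical. Like rowwise_in, it materialises an (a0, b0, c) comparison tensor, so memory grows with the product of the two row counts.

```python
import torch

def rowwise_in_bool(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # a: (a0, c), b: (b0, c)
    # a[:, None, :] is (a0, 1, c) and b[None, :, :] is (1, b0, c);
    # broadcasting compares every row of a with every row of b -> (a0, b0, c)
    equal = a[:, None, :] == b[None, :, :]
    # a pair of rows matches only if all c entries are equal -> (a0, b0)
    row_equal = equal.all(dim=2)
    # a row of a is "in" b if it matches at least one row of b -> (a0,)
    return row_equal.any(dim=1)

t1 = torch.tensor([[1, 2], [3, 4], [5, 6]])
t2 = torch.tensor([[1, 2], [5, 6]])
print(rowwise_in_bool(t1, t2).tolist())  # [True, False, True]
```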