Is there a torch function to derive union of two tensors?

Time:10-07

NumPy has a function that computes the union of two tensors, as shown below:

import torch
import numpy as np
a = torch.tensor([0, 1, 2])
b = torch.tensor([2, 3, 4])
c = np.union1d(a, b) # c = array([0, 1, 2, 3, 4])
c = torch.from_numpy(c) # c = torch.tensor([0, 1, 2, 3, 4])

However, I am looking for a torch function that can be used directly on two tensors. With the NumPy function above, I have to convert the result from NumPy back to torch, and the computation has to run on the CPU even though the inputs are tensors.
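For reference, the NumPy round trip described above looks roughly like this when written so that it also accepts tensors that live on a GPU. This is a minimal sketch; the helper name union1d_via_numpy is hypothetical. It makes the cost explicit: both inputs are copied to host memory and the result is copied back.

```python
import numpy as np
import torch

def union1d_via_numpy(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper. NumPy only works on host memory,
    # so both tensors are copied to the CPU first.
    c = np.union1d(a.cpu().numpy(), b.cpu().numpy())
    # Convert back to a tensor and move it to the device of the first input.
    return torch.from_numpy(c).to(a.device)

a = torch.tensor([0, 1, 2])
b = torch.tensor([2, 3, 4])
print(union1d_via_numpy(a, b))  # tensor([0, 1, 2, 3, 4])
```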

Is there a union function in torch that works directly on two tensors? Or, failing that, can one be implemented simply with other torch functions?

Answer:

You can use:

torch.cat((a, b)).unique()
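A minimal sketch wrapping this one-liner in a helper (the name union1d is hypothetical, not a PyTorch API). torch.unique returns the distinct values in ascending order when sorted=True, which matches the output of np.union1d. The inputs are flattened first because np.union1d also flattens multi-dimensional inputs and torch.cat would otherwise reject tensors of incompatible shapes. Both inputs must be on the same device, and the result stays on that device, so no CPU round trip is needed.

```python
import torch

def union1d(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper. Flatten both inputs (np.union1d also flattens)
    # and concatenate them; torch.cat needs both tensors on the same device.
    combined = torch.cat((a.flatten(), b.flatten()))
    # torch.unique drops duplicates; sorted=True returns them in ascending order.
    return torch.unique(combined, sorted=True)

# Use a GPU if one is available; the helper itself is device-agnostic.
device = "cuda" if torch.cuda.is_available() else "cpu"
a = torch.tensor([0, 1, 2], device=device)
b = torch.tensor([2, 3, 4], device=device)
c = union1d(a, b)  # contains 0, 1, 2, 3, 4 and stays on the same device as a and b
print(c)
```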