How to count the number of different elements in a 3D tensor (or array) fast?

Time:07-27

Given a 3D tensor (or NumPy array), say of size [2, 2, 3]:

import torch
T = torch.tensor([[[1, 1, 1], [1, 2, 1]], [[2, 2, 2], [1, 4, 5]]])

tensor([[[1, 1, 1],
         [1, 2, 1]],
        [[2, 2, 2],
         [1, 4, 5]]])

I want to count the number of distinct elements in each row (i.e. along the last dimension) and expect this result:

tensor([[[1],
         [2]],
        [[1],
         [3]]])

Currently I am using:

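# One distinct-value count per row, stored with a trailing axis of size 1
color_count = torch.zeros((T.shape[0], T.shape[1], 1), dtype=torch.long)
for i in range(T.shape[0]):
    for j in range(T.shape[1]):
        # unique() returns the distinct values of row (i, j)
        color_count[i, j, 0] = len(T[i, j].unique())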

But it is too slow, since I have to do this many times. Could anyone help me speed it up?
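For reference, the nested loops can be avoided entirely in PyTorch. The sketch below assumes non-empty rows and values without NaN: it sorts each row, then counts the positions where adjacent sorted values differ, adding one for the first element. The helper name count_distinct_per_row is hypothetical, not a PyTorch API.

```python
import torch

def count_distinct_per_row(T: torch.Tensor) -> torch.Tensor:
    """Count distinct values along the last dimension, keeping it as size 1."""
    # Sorting every row at once places equal values next to each other.
    S, _ = torch.sort(T, dim=-1)
    # Each position where adjacent sorted values differ starts a new distinct value.
    changes = S[..., 1:] != S[..., :-1]
    # Add 1 for the first element of each row (assumes rows are non-empty).
    return 1 + changes.sum(dim=-1, keepdim=True)

T = torch.tensor([[[1, 1, 1], [1, 2, 1]], [[2, 2, 2], [1, 4, 5]]])
print(count_distinct_per_row(T))
```

All rows are processed by a handful of tensor operations instead of one unique() call per row, which is typically where the speedup over the Python loop comes from.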

Answer:

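If a NumPy solution suits you, np.apply_along_axis can replace the nested loops with a single call. Note that it still calls the function once per row in Python, so it is mainly more concise; expect at most a modest speedup over the explicit loops.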

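import numpy as np
# Assuming `T` is a 3D NumPy array (e.g. T.numpy() for a CPU tensor);
# returning a one-element list keeps a trailing axis of size 1.
np.apply_along_axis(lambda row: [len(np.unique(row))], axis=2, arr=T)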

Also note that len(set(row)) may be faster than len(np.unique(row)) for short rows.
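The same sort-and-compare idea also works in NumPy and avoids the per-row Python call that np.apply_along_axis makes. A minimal sketch, assuming non-empty rows and values without NaN; count_distinct_per_row_np is a hypothetical helper name, and the set-based apply_along_axis variant is included only to check that both give the same result.

```python
import numpy as np

def count_distinct_per_row_np(T: np.ndarray) -> np.ndarray:
    """Count distinct values along the last axis, keeping it as size 1."""
    # Sort every row at once so duplicates become adjacent.
    S = np.sort(T, axis=-1)
    # Each position where adjacent sorted values differ starts a new distinct value.
    changes = S[..., 1:] != S[..., :-1]
    # Add 1 for the first element of each row (assumes rows are non-empty).
    return 1 + changes.sum(axis=-1, keepdims=True)

T = np.array([[[1, 1, 1], [1, 2, 1]], [[2, 2, 2], [1, 4, 5]]])

# Per-row variant using len(set(row)), for comparison with the vectorized version.
per_row = np.apply_along_axis(lambda row: [len(set(row))], axis=2, arr=T)

print(count_distinct_per_row_np(T))
print(np.array_equal(count_distinct_per_row_np(T), per_row))
```

Sorting costs O(k log k) per row of length k, but it runs in compiled code over all rows at once, so for many short rows it is usually faster than calling a Python function per row.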
