Find out which rows of one 2D numpy array are represented in another 2D numpy array

Time:11-07

I have two arrays:

a = [[1, 2, 3],
     [4, 5, 6],
     [7, 8, 9],
     [10, 11, 12],
     [13, 14, 15]]

b = [[1, 2, 3],
     [4, 5, 6],
     [13, 14, 15]]

I want to find out which rows of the first array also appear in the second array:

desired_output = [1, 1, 0, 0, 1]

I have tried this code:

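import numpy as np

x = a == b[:, None]            # compares every row of b with every row of a
row_sums = np.sum(x, axis=2)   # number of matching elements per pair of rows
output = np.sum(np.where(row_sums == a.shape[1], 1, 0), axis=0)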

But it creates a massive 3D array, x, with shape (rows of b, rows of a, columns).

x.shape = (3, 5, 3)

Since my arrays are large, this takes my computer a long time to compute. Does anyone have ideas on how to improve my code?

CodePudding user response:

You can compare a against each row of b separately and then combine the results with a logical OR. Here is sample code:

np.logical_or.reduce([(a == b[i]).all(axis=1) for i in range(b.shape[0])])

Result:

array([ True,  True, False, False,  True])
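For reference, here is a self-contained sketch of the same idea using the arrays from the question. The helper name rows_in is made up for illustration, and it assumes both inputs are 2D NumPy arrays with the same number of columns. It accumulates the OR in place instead of building a list first, so only one 2D boolean array exists at a time.

```python
import numpy as np

def rows_in(a, b):
    """Return a boolean mask: True where a row of `a` also appears in `b`."""
    mask = np.zeros(a.shape[0], dtype=bool)
    for row in b:
        # (a == row) broadcasts to shape (len(a), n_cols); no 3D array is built
        mask |= (a == row).all(axis=1)
    return mask

a = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12], [13, 14, 15]])
b = np.array([[1, 2, 3], [4, 5, 6], [13, 14, 15]])

mask = rows_in(a, b)
output = mask.astype(int)  # 0/1 form, as in desired_output
```

The total work is still proportional to len(a) * len(b), but peak memory stays proportional to the size of a.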

CodePudding user response:

Here's a solution that is O(len(a) + len(b)), rather than O(len(a) * len(b)):

set_b = {tuple(row) for row in b}
[tuple(row) in set_b for row in a]

Result:

[True, True, False, False, True]

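This relies on tuples being immutable and therefore hashable (as long as their elements are), so you can use them as members of a set (or keys in a dict). Building the set on the first line is O(len(b)), and the membership tests on the second line are O(len(a)) on average.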
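A runnable sketch of the set-based approach, with the result converted to the 0/1 form from the question. The function name rows_in_set is hypothetical, and the inputs are assumed to be 2D NumPy arrays.

```python
import numpy as np

def rows_in_set(a, b):
    """Mark each row of `a` with 1 if it appears in `b`, else 0."""
    # .tolist() gives nested Python lists, so the tuples hold plain Python values
    set_b = {tuple(row) for row in b.tolist()}
    return np.array([tuple(row) in set_b for row in a.tolist()], dtype=int)

a = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12], [13, 14, 15]])
b = np.array([[1, 2, 3], [4, 5, 6], [13, 14, 15]])

result = rows_in_set(a, b)
```

Membership is tested with exact equality, so floating-point rows that differ only by rounding error will not match.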
