I have two lists of lists, X and Y, that I need to filter using a third list of lists, Z. The values in Y are the scores of the corresponding elements in X. For each row, I need to keep the values of X that appear in the matching row of Z, along with their scores from Y.
I will illustrate with an example and my current solution.
# Base matrices
X = [[2,1,3,4],
     [1,3,2,4],
     [1,2,3,4]]
Y = [[0.2,0.1,0.9,1.0],
     [0.3,0.2,0.4,0.2],
     [0.8,0.6,0.5,0.2]]
Z = [[1,2,3,4],
     [2,3],
     [1]]
# Expected results
new_x = [[2,1,3,4],
         [3,2],
         [1]]
new_y = [[0.2,0.1,0.9,1.0],
         [0.2,0.4],
         [0.8]]
# Current solution
import numpy as np

def find_idx(a, b):
    # Indices of the elements of a that also appear in b
    r = []
    for idx, sub_a in enumerate(a):
        if sub_a in b:
            r += [idx]
    return r

def filter(X, Y, Z):
    X = np.asarray(X)
    Y = np.asarray(Y)
    # Z is ragged, so it stays a plain list of lists
    assert len(X) == len(Y) == len(Z)
    r_x = []
    r_y = []
    for idx, sub_filter in enumerate(Z):
        x = find_idx(X[idx], sub_filter)
        r_x.append(X[idx][x].tolist())
        r_y.append(Y[idx][x].tolist())
    return r_x, r_y

r_x, r_y = filter(X, Y, Z)
I figured out I could easily do this with a collection of list comprehensions, but performance is important for this function.
Is there any way to speed up the part where I find the indexes of the values of X that are in Z, so that I can then filter X and Y by them?
CodePudding user response:
This is a more efficient way to do that when the input matrices are big:
import numpy as np

X = np.array([
    [2, 1, 3, 4],
    [1, 3, 2, 4],
    [1, 2, 3, 4],
])
Y = np.array([
    [0.2, 0.1, 0.9, 1.0],
    [0.3, 0.2, 0.4, 0.2],
    [0.8, 0.6, 0.5, 0.2],
])
Z = [
    [1, 2, 3, 4],
    [2, 3],
    [1],
]

new_x = []
new_y = []
for i, z_row in enumerate(Z):
    # Boolean mask of the elements of row i of X that appear in z_row
    mask = np.isin(X[i], z_row)
    new_x.append(X[i][mask].tolist())
    new_y.append(Y[i][mask].tolist())
It was roughly 10x faster than the list comprehension when I tested it with 5000x5000 matrices. This is because, in the list comprehension, every membership test with the in operator has to loop over all elements of the list z.
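To illustrate the point about the in operator: in pure Python, converting each row of Z to a set turns every membership test into a hash lookup (constant time on average) instead of a scan of the list. The following is only a minimal sketch, assuming the values are hashable (ints here); filter_with_sets is a hypothetical helper name, not something from the answer above, and it has not been benchmarked against the NumPy version.

```python
def filter_with_sets(X, Y, Z):
    """Keep, row by row, the values of X found in Z and their scores from Y."""
    new_x, new_y = [], []
    for x_row, y_row, z_row in zip(X, Y, Z):
        allowed = set(z_row)  # built once per row, then reused for every lookup
        # Positions in this row of X whose value is allowed by the filter
        keep = [i for i, x in enumerate(x_row) if x in allowed]
        new_x.append([x_row[i] for i in keep])
        new_y.append([y_row[i] for i in keep])
    return new_x, new_y


X = [[2, 1, 3, 4], [1, 3, 2, 4], [1, 2, 3, 4]]
Y = [[0.2, 0.1, 0.9, 1.0], [0.3, 0.2, 0.4, 0.2], [0.8, 0.6, 0.5, 0.2]]
Z = [[1, 2, 3, 4], [2, 3], [1]]

new_x, new_y = filter_with_sets(X, Y, Z)
print(new_x)
print(new_y)
```

The order of the kept elements follows X, not Z, which matches the expected results in the question.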
CodePudding user response:
Using a nested list comprehension:
x_new = [[v for v in x if v in z] for x, z in zip(X, Z)]
Output:
[[2, 1, 3, 4],
[3, 2],
[1]]
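The output above covers only the filtered X. A membership-based sketch that also keeps the matching scores from Y might look like the following. It pairs each value with its score before filtering, so the two results cannot drift out of step. The names pairs, x_row, y_row and z_row are illustrative and do not come from the answer.

```python
X = [[2, 1, 3, 4], [1, 3, 2, 4], [1, 2, 3, 4]]
Y = [[0.2, 0.1, 0.9, 1.0], [0.3, 0.2, 0.4, 0.2], [0.8, 0.6, 0.5, 0.2]]
Z = [[1, 2, 3, 4], [2, 3], [1]]

# For every row, keep the (value, score) pairs whose value appears in the filter row
pairs = [[(x, y) for x, y in zip(x_row, y_row) if x in z_row]
         for x_row, y_row, z_row in zip(X, Y, Z)]

# Split the pairs back into two lists of lists
new_x = [[x for x, _ in row] for row in pairs]
new_y = [[y for _, y in row] for row in pairs]

print(new_x)
print(new_y)
```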
CodePudding user response:
new_x = []
new_y = []
zipped_xy = [list(zip(*el)) for el in zip(X, Y)]
for idx, v in enumerate(Z):
    temp_x = []
    temp_y = []
    for x, y in zipped_xy[idx]:
        if x in v:
            temp_x.append(x)
            temp_y.append(y)
    new_x.append(temp_x)
    new_y.append(temp_y)
print(new_x)
print(new_y)
# [[2, 1, 3, 4], [3, 2], [1]]
# [[0.2, 0.1, 0.9, 1.0], [0.2, 0.4], [0.8]]