Get a mask of false except for specific index

Time:11-20

I'm learning NumPy and trying to find a better way to code this example. I have an array giving the cluster that each element (of another array) belongs to: element 0 belongs to cluster 1, element 1 to cluster 2, and so on. I'd like to build a cluster map that uses a boolean mask per cluster to mark the elements belonging to it. The code below works, but I dislike the two tmp_mask lines and wonder whether they can be avoided.

import numpy as np

cluster = np.array([1,2,1,1,2,3,1,2])
cluster_map = {}

empty_mask = np.zeros(len(cluster), dtype=bool)

for idx, cl in enumerate(cluster):
    tmp_mask = empty_mask.copy() 
    tmp_mask[idx] = True
    cluster_map[cl] = cluster_map.get(cl, empty_mask) | tmp_mask

cluster_map

Just trying to see if there is a shorter version, for example:

    #tmp_mask = empty_mask.copy() 
    #tmp_mask[idx] = True
    cluster_map[cl] = cluster_map.get(cl, empty_mask) | get_falses_except(idx, len(cluster))

I know I can write a get_falses_except function myself; I'm just wondering whether something like it already exists, or whether the code can be rewritten in a better way.

Thank you all
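
For reference, the helper imagined in the question could be sketched as below. I'm not aware of a dedicated NumPy function for this, so `get_falses_except` here is a hypothetical name taken from the question, implemented by comparing `np.arange(n)` with the index. Converting the label to a plain `int` for the dict key is my own choice, not part of the original code.

```python
import numpy as np

def get_falses_except(idx, n):
    """Hypothetical helper: boolean mask of length n, True only at idx."""
    # Comparing every position 0..n-1 with idx gives a one-hot boolean mask.
    return np.arange(n) == idx

cluster = np.array([1, 2, 1, 1, 2, 3, 1, 2])
cluster_map = {}

for idx, cl in enumerate(cluster):
    key = int(cl)  # plain int keys (an assumption, for readability)
    # Start from an all-False mask the first time a cluster is seen.
    previous = cluster_map.get(key, np.zeros(len(cluster), dtype=bool))
    cluster_map[key] = previous | get_falses_except(idx, len(cluster))
```

This keeps the loop from the question and only replaces the two tmp_mask lines; the answers below avoid the per-element loop altogether.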

CodePudding user response:

Advanced indexing might help you (this assumes the cluster labels are the integers 1 to K, so that cluster-1 is a valid row index):

empty_mask = np.zeros((np.max(cluster), len(cluster)), dtype=bool)
empty_mask[cluster-1, np.arange(len(cluster))] = True
>>> empty_mask
array([[ True, False,  True,  True, False, False,  True, False],
       [False,  True, False, False,  True, False, False,  True],
       [False, False, False, False, False,  True, False, False]])

You can also build a dict if you need one:

>>> dict(zip(range(1, empty_mask.shape[0] + 1), empty_mask))
{1: array([ True, False,  True,  True, False, False,  True, False]),
 2: array([False,  True, False, False,  True, False, False,  True]),
 3: array([False, False, False, False, False,  True, False, False])}
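
If the labels are not the consecutive integers 1 to K, the same advanced-indexing idea can still be used by first mapping each label to a row number. The sketch below does that with `np.unique(..., return_inverse=True)`; the label values 10, 20, 35 and the name `masks` are made up for illustration.

```python
import numpy as np

# Made-up, non-consecutive labels (illustrative only).
cluster = np.array([10, 20, 10, 10, 20, 35, 10, 20])

# labels: sorted unique values; inverse: row index of each element's label.
labels, inverse = np.unique(cluster, return_inverse=True)

masks = np.zeros((labels.size, cluster.size), dtype=bool)
# Set one True per column, in the row that matches the element's label.
masks[inverse, np.arange(cluster.size)] = True

# Pair each label with its row to get the dict form.
cluster_map = dict(zip(labels.tolist(), masks))
```

Each row of `masks` is the mask for one label, in the sorted order returned by `np.unique`.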

CodePudding user response:

You can do it with a pretty straightforward comprehension:

In [8]: {k: cluster==k for k in set(cluster)}
Out[8]:
{1: array([ True, False,  True,  True, False, False,  True, False]),
 2: array([False,  True, False, False,  True, False, False,  True]),
 3: array([False, False, False, False, False,  True, False, False])}

CodePudding user response:

I think the for loop can be skipped too, but I don't know how.

import numpy as np
cluster = np.array([1,2,1,1,2,3,1,2])

cluster_set=np.unique(cluster)
cluster_map=np.zeros((cluster_set.size, cluster.size), dtype=bool)
for idx, val in enumerate(cluster_set):
    cluster_map[idx]=cluster==val

print(cluster_map)
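
One way the loop above could be dropped is with broadcasting: comparing a column of unique labels against the flat `cluster` array produces the whole 2-D mask in a single expression. This is a sketch of that idea, not necessarily the fastest option for very large inputs, since it materializes a (clusters × elements) array.

```python
import numpy as np

cluster = np.array([1, 2, 1, 1, 2, 3, 1, 2])
cluster_set = np.unique(cluster)

# cluster_set[:, None] has shape (3, 1); cluster has shape (8,).
# The comparison broadcasts to shape (3, 8), one row per unique label.
cluster_map = cluster_set[:, None] == cluster

print(cluster_map)
```

Row i of the result is the mask for `cluster_set[i]`, matching what the loop fills in.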