I'm learning NumPy and looking for a better way to write this example. I have an array that gives the cluster each element (of another array) belongs to: element 0 belongs to cluster 1, element 1 to cluster 2, and so on. I'd like to build a cluster map that uses a boolean mask to mark the elements belonging to each cluster. The code below works, but I dislike the two tmp_mask lines and wonder whether they can be avoided.
import numpy as np

cluster = np.array([1, 2, 1, 1, 2, 3, 1, 2])
cluster_map = {}
empty_mask = np.zeros(len(cluster), dtype=bool)
for idx, cl in enumerate(cluster):
    # One-hot mask: True only at position idx
    tmp_mask = empty_mask.copy()
    tmp_mask[idx] = True
    # OR it into the mask accumulated so far for this cluster
    cluster_map[cl] = cluster_map.get(cl, empty_mask) | tmp_mask
cluster_map
Just trying to see if there is a shorter version, for example:
#tmp_mask = empty_mask.copy()
#tmp_mask[idx] = True
cluster_map[cl] = cluster_map.get(cl, empty_mask) | get_falses_except(idx, len(cluster))
I know I can write a get_falses_except function myself; I'm just wondering whether something like it already exists, or whether the code can be rewritten in a better way.
Thank you all
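For reference, the helper described in the question could look roughly like the sketch below. The name get_falses_except comes from the question and is hypothetical, not a NumPy function; the sketch assumes a one-hot mask is all that is needed and builds it by comparing an index range against idx. Keys are converted to plain Python ints, which is an extra choice made here for readability.

```python
import numpy as np

def get_falses_except(idx, n):
    """Hypothetical helper: boolean array of length n, True only at position idx."""
    # Comparing the index range 0..n-1 against idx yields a one-hot boolean mask.
    return np.arange(n) == idx

cluster = np.array([1, 2, 1, 1, 2, 3, 1, 2])
n = len(cluster)
cluster_map = {}
for idx, cl in enumerate(cluster):
    key = int(cl)  # plain int key instead of a NumPy scalar
    # Start from an all-False mask the first time a cluster label is seen.
    previous = cluster_map.get(key, np.zeros(n, dtype=bool))
    cluster_map[key] = previous | get_falses_except(idx, n)
```

This keeps the original loop and only replaces the two tmp_mask lines.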
CodePudding user response:
Advanced indexing might help you:
empty_mask = np.zeros((np.max(cluster), len(cluster)), dtype=bool)
empty_mask[cluster-1, np.arange(len(cluster))] = True
>>> empty_mask
array([[ True, False, True, True, False, False, True, False],
[False, True, False, False, True, False, False, True],
[False, False, False, False, False, True, False, False]])
You can also return a dict if you need:
>>> dict(zip(range(1, empty_mask.shape[0] + 1), empty_mask))
{1: array([ True, False, True, True, False, False, True, False]),
2: array([False, True, False, False, True, False, False, True]),
3: array([False, False, False, False, False, True, False, False])}
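Note that indexing rows with cluster-1 assumes the labels are exactly 1..K with no gaps. If that does not hold, one possible variation (a sketch, not part of the answer above) is to let np.unique with return_inverse=True map each label to a row index first. The labels 10, 20, 35 below are made up for illustration, and the name masks is chosen here.

```python
import numpy as np

# Made-up, non-contiguous labels for illustration
cluster = np.array([10, 20, 10, 10, 20, 35, 10, 20])

# labels: sorted unique values; inverse: for each element, the row index of its label
labels, inverse = np.unique(cluster, return_inverse=True)

# One row per label, one column per element
masks = np.zeros((labels.size, cluster.size), dtype=bool)
masks[inverse, np.arange(cluster.size)] = True

# Optional: dict keyed by the original labels (as plain Python ints)
cluster_map = {int(label): row for label, row in zip(labels, masks)}
```

Each column of masks contains exactly one True, in the row of that element's label.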
CodePudding user response:
You can do it with a pretty straightforward comprehension:
In [8]: {k: cluster==k for k in set(cluster)}
Out[8]:
{1: array([ True, False, True, True, False, False, True, False]),
2: array([False, True, False, False, True, False, False, True]),
3: array([False, False, False, False, False, True, False, False])}
CodePudding user response:
I think the for loop can be skipped too, but I don't know how:
import numpy as np
cluster = np.array([1,2,1,1,2,3,1,2])
cluster_set=np.unique(cluster)
cluster_map=np.zeros((cluster_set.size, cluster.size), dtype=bool)
for idx, val in enumerate(cluster_set):
    cluster_map[idx] = cluster == val
print(cluster_map)
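One way the loop in the snippet above can probably be dropped is broadcasting: comparing a column of unique labels against the flat cluster array produces the whole boolean matrix in one step. This is a minimal sketch that reuses the names from the snippet above.

```python
import numpy as np

cluster = np.array([1, 2, 1, 1, 2, 3, 1, 2])
cluster_set = np.unique(cluster)

# cluster_set[:, None] has shape (3, 1); comparing it with cluster, shape (8,),
# broadcasts to a (3, 8) boolean array with one row per unique label.
cluster_map = cluster_set[:, None] == cluster
print(cluster_map)
```

Row i of cluster_map is the mask for cluster_set[i], so no preallocation with np.zeros is needed.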