Group elements in ndarray by index

Time:01-31

I have an image dataset of 1,000 images, for which I have created embeddings. Each image has 512 embeddings, each a 256-d vector, so each image's embeddings form an ndarray of shape (512, 256), and the full array has shape (1000, 512, 256).

Now, across all 1,000 images, I want to create a group of observations for the first of the 512 embeddings by collecting that embedding from each image. Then I want to do the same for the second embedding, the third, the fourth, and so on up to the 512th.
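To make the requested grouping concrete, here is a toy sketch that assumes small stand-in dimensions (3 images, 4 embeddings per image, 2-d vectors) in place of (1000, 512, 256); the values are arbitrary, built with `np.arange`. Group `k` should hold the `k`-th embedding of every image, i.e. an array of shape `(n_images, dim)`.

```python
import numpy as np

# Toy stand-in for the real (1000, 512, 256) array: 3 images, 4 embeddings, 2-d vectors.
n_images, n_embeddings, dim = 3, 4, 2
embeddings = np.arange(n_images * n_embeddings * dim).reshape(n_images, n_embeddings, dim)

# Desired group for embedding index k: that embedding taken from every image,
# stacked into shape (n_images, dim).
k = 1
group_k = np.stack([embeddings[img, k] for img in range(n_images)])

print(group_k.shape)  # (3, 2)
print(group_k.tolist())  # [[2, 3], [10, 11], [18, 19]]
```

The same group is simply the slice `embeddings[:, k, :]`, which is what the answer below builds on.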

How would I go about creating these groups?

Answer:

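You can achieve that by slicing out the i-th embedding from every image and stacking the slices:

import numpy as np

groups = []

for i in range(embeddings.shape[1]):  # 512 embeddings per image
    # Select the i-th embedding from each image -> shape (1000, 256)
    group = embeddings[:, i, :]
    groups.append(group)

# Stack into shape (512, 1000, 256): groups[i] is the group for embedding i
groups = np.array(groups)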
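As a sanity check of the loop above, here is a sketch that assumes a random stand-in array with fewer images (10 instead of 1,000) so it runs quickly; the number of embeddings is read from the array rather than hardcoded. The stacked result has shape `(512, n_images, 256)`, and `groups[i][j]` is image `j`'s `i`-th embedding.

```python
import numpy as np

# Random stand-in data: 10 images (instead of 1000), 512 embeddings, 256-d vectors.
rng = np.random.default_rng(0)
embeddings = rng.standard_normal((10, 512, 256)).astype(np.float32)

groups = []
for i in range(embeddings.shape[1]):
    # i-th embedding from every image -> shape (10, 256)
    groups.append(embeddings[:, i, :])
groups = np.array(groups)

print(groups.shape)  # (512, 10, 256)

# groups[i][j] is the i-th embedding of image j.
i, j = 7, 3
print(np.array_equal(groups[i][j], embeddings[j, i]))  # True
```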

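A shorter, vectorized alternative swaps the first two axes directly. The result has the same shape, (512, 1000, 256), with groups[i] again holding the i-th embedding of every image, and it is a view of embeddings rather than a copy:

groups = np.transpose(embeddings, (1, 0, 2))  # same as embeddings.swapaxes(0, 1)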
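A sketch, assuming NumPy and a random stand-in array (10 images instead of 1,000): swapping axis 0 (images) and axis 1 (embedding index) with `np.transpose(embeddings, (1, 0, 2))` gives the grouped layout in one call and matches an explicit per-index loop. Because the result is a view, no data is duplicated; for the full array at float32 (an assumed dtype) that avoids a second copy of about 524 MB. `np.ascontiguousarray` makes a real copy if a contiguous layout is needed.

```python
import numpy as np

rng = np.random.default_rng(0)
embeddings = rng.standard_normal((10, 512, 256)).astype(np.float32)

# Swap axis 0 (images) and axis 1 (embedding index): groups[i] is group i.
groups = np.transpose(embeddings, (1, 0, 2))
print(groups.shape)  # (512, 10, 256)

# The transpose is a view over the same memory, not a copy.
print(np.shares_memory(groups, embeddings))  # True

# Same values as building the groups with an explicit loop.
looped = np.array([embeddings[:, i, :] for i in range(embeddings.shape[1])])
print(np.array_equal(groups, looped))  # True

# Materialize a C-contiguous copy only if needed (e.g. before saving to disk).
groups_contig = np.ascontiguousarray(groups)
```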