Split a list (numpy array, etc.) into groups of a given length, but keeping the same elements within

Time:03-21

I would like to split a list (NumPy array, etc.) into groups of a given length while keeping equal elements in the same group, even if that makes a group larger than the specified length.

For example, with a group size of 5 and the following list:

>>> import numpy as np
>>> x = np.array([5] * 19 + [1] + [4] * 4 + [2] * 4 + [3] * 2)
>>> x
array([5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 1, 4, 4,
       4, 4, 2, 2, 2, 2, 3, 3])

The output should preferably be a list of the indices of the grouped elements:

>>> [np.where(x == 5), np.where((x == 1) | (x == 4)),
...  np.where(x == 2), np.where(x == 3)]

[(array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
         17, 18]),),
 (array([19, 20, 21, 22, 23]),),
 (array([24, 25, 26, 27]),),
 (array([28, 29]),)]

The code needs to be efficient enough to process a list of about 50k elements in a reasonable amount of time.

Do you have any ideas on how to proceed? I experimented with pandas groupby, numpy.split and similar functions, but could not come up with a viable algorithm.

Answer:

You can solve this problem by sorting the array while keeping the sorting permutation, and then splitting that permutation wherever two consecutive values in the sorted array differ.

import numpy as np

# Sort the input array and keep the permutation (so that x[indices] is sorted)
indices = np.argsort(x, kind='stable')
sorted_x = x[indices]

# Find the start index of every group except the first one (which starts at 0):
# position i starts a new group when sorted_x[i] != sorted_x[i - 1]
groupStartIndex = np.where(np.insert(sorted_x[1:] != sorted_x[:-1], 0, 0))[0]

# Split the indices into groups
result = np.split(indices, groupStartIndex)
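To make each step concrete, here is the same sequence of calls run on a small array; the values are arbitrary and chosen only for illustration, and the intermediate name is_start is introduced here just to show the boolean mask.

```python
import numpy as np

x = np.array([3, 1, 3, 2, 1])

# Stable argsort: equal values keep their original relative order
indices = np.argsort(x, kind='stable')   # [1, 4, 3, 0, 2]
sorted_x = x[indices]                    # [1, 1, 2, 3, 3]

# True where a sorted value differs from its predecessor; the inserted 0
# (cast to False) fills position 0, which has no predecessor
is_start = np.insert(sorted_x[1:] != sorted_x[:-1], 0, 0)  # [F, F, T, T, F]
groupStartIndex = np.where(is_start)[0]                    # [2, 3]

# Cut the permutation at those positions: one index array per distinct value
result = np.split(indices, groupStartIndex)
print([g.tolist() for g in result])  # [[1, 4], [3], [0, 2]]
```

Each array in result holds the positions in x of one value, and the arrays come out ordered by that value.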

The result is the following:

[array([19]),
 array([24, 25, 26, 27]),
 array([28, 29]),
 array([20, 21, 22, 23]),
 array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18])]

The group index arrays are ordered by the value associated with each group (here 1, 2, 3, 4, 5). You can get the associated values (i.e. the keys) with:

group_values = sorted_x[np.insert(groupStartIndex, 0, 0)]
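A variation that is not part of the answer above derives the same cut points from np.unique with return_counts=True. Because np.unique returns the distinct values in sorted order, the cumulative counts mark where each group ends in the stable-sorted permutation. This sketch reuses the question's example array.

```python
import numpy as np

x = np.array([5] * 19 + [1] + [4] * 4 + [2] * 4 + [3] * 2)

indices = np.argsort(x, kind='stable')
# Distinct values (sorted ascending) and how often each one occurs
group_values, counts = np.unique(x, return_counts=True)
# Cumulative counts are the group ends in sorted order;
# dropping the last one (== x.size) leaves the cut points for np.split
result = np.split(indices, np.cumsum(counts)[:-1])

for value, idx in zip(group_values, result):
    print(value, idx.tolist())
```

It produces the same groups and keys as the boundary-mask version; which one to use is mostly a matter of taste.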

Note that if equal items are already contiguous in the input array (e.g. no pattern like [5, 5, 1, 5]), you can skip np.argsort entirely: use np.arange(x.size) for the indices and compute the group starts directly on x instead of sorted_x.
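The answer's grouping keeps equal values together but does not enforce the requested group size (value 1 ends up alone in its group). The question does not state the merging rule precisely; one reading that reproduces its expected output is to walk the array from left to right and, from each group start, cut at the last run boundary that keeps the group within size elements, or take the whole run if that run alone is longer than size. The sketch below implements that reading, assuming equal values are already contiguous (as in the example), so the run boundaries come straight from x. The helper name split_keeping_runs is hypothetical.

```python
import numpy as np

def split_keeping_runs(x, size):
    """Hypothetical helper: split the positions of x into consecutive groups
    of at most `size` elements without splitting a run of equal values;
    a run longer than `size` becomes one oversized group.
    Assumes equal values are contiguous in x."""
    x = np.asarray(x)
    n = x.size
    # Run boundaries: 0, every position where the value changes, and n
    boundaries = np.concatenate(([0], np.flatnonzero(x[1:] != x[:-1]) + 1, [n]))
    groups = []
    start = 0
    while start < n:
        # Index of the last boundary that keeps the group within `size` elements
        k = np.searchsorted(boundaries, start + size, side='right') - 1
        end = boundaries[k]
        if end <= start:
            # The run beginning at `start` is longer than `size`: take all of it
            end = boundaries[k + 1]
        groups.append(np.arange(start, end))
        start = end
    return groups


x = np.array([5] * 19 + [1] + [4] * 4 + [2] * 4 + [3] * 2)
for g in split_keeping_runs(x, 5):
    print(g.tolist())
```

With size=5 this yields the positions 0 to 18, then [19, 20, 21, 22, 23], [24, 25, 26, 27] and [28, 29], matching the expected output in the question. Each iteration is one binary search, so 50k elements should not be a problem. If equal values are not contiguous, one option would be to call it on sorted_x and map each group back with indices[g], though whether that matches the intended rule is up to the asker.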
