Clustering 1D vector with window

Time:07-05

I am trying to identify clusters of 1s in a 1D vector. Clusters that are separated by fewer zeros than a given threshold should be grouped together. Say, if two clusters are separated by fewer than 3 zeros, they should be considered one large cluster. For instance, the following vector:

[0,0,0,1,1,1,1,0,0,0,1,1,0,1,0,1,0,0,0,0,1,1,1]

should give me three clusters (with non-zero numbers indicating cluster ID):

[0,0,0,1,1,1,1,0,0,0,2,2,2,2,2,2,0,0,0,0,3,3,3]

I've been scratching my head all day trying to use rolling() in pandas along with some custom-made functions, but I can't come up with anything that works.

CodePudding user response:

As you have tried, this can be done with pandas and rolling, although it's a bit involved. There may be a simpler solution, but the approach below should work.

First, create the dataframe from the data. We pad the data with 2 zeros (1 less than the minimum number of zeros between clusters) at both the start and the end, so that no cluster touches either end of the dataframe. This simplifies the later logic.

import pandas as pd

data = [0,0,0,1,1,1,1,0,0,0,1,1,0,1,0,1,0,0,0,0,1,1,1]
df = pd.DataFrame({'val': [0,0] + data + [0,0]})

Now, we compute the rolling sum on the dataframe with window size 3 and min_periods 3. This will give 2 NaNs at the start (for our padded zero values). We mark every row where the sum is above 0 and divide the dataframe into groups depending on these values using diff and cumsum:

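df['sum'] = df['val'].rolling(window=3, min_periods=3).sum()
# cast to int so diff() stays numeric; abs() counts both 0 -> 1 and 1 -> 0 changes
df['group'] = (df['sum'] > 0).astype(int).diff().abs().cumsum().bfill()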
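
To see what the rolling step produces, the group numbers can be computed on their own. This is a minimal sketch that uses the question's vector; the variable names `padded`, `window_sum`, `active` and `group` are chosen here for illustration and are not part of the answer's dataframe.

```python
import pandas as pd

data = [0,0,0,1,1,1,1,0,0,0,1,1,0,1,0,1,0,0,0,0,1,1,1]
padded = pd.Series([0, 0] + data + [0, 0])

# Sum of the current value and the 2 values before it (NaN for the first 2 rows).
window_sum = padded.rolling(window=3, min_periods=3).sum()

# True wherever a 1 occurs at this position or at one of the 2 preceding positions.
active = window_sum > 0

# Every change of `active` (in either direction) starts a new group.
group = active.astype(int).diff().abs().cumsum().fillna(0).astype(int)

print(group.tolist())
# [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 2, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 5, 5, 5, 5, 5]
```

Because the window looks back 2 positions, a gap of 1 or 2 zeros never brings the sum down to 0, so the clusters on either side of it end up in the same odd-numbered group.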

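Any group with an odd number will be a cluster. However, each of these groups ends with 2 extra rows, because the window still contains the cluster's last 1 for 2 rows after the cluster has ended. We remove these and assign the correct cluster id using a custom function with groupby and apply: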

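def adjust_group_value(x):
    x = x.copy()  # work on a copy instead of modifying the group in place
    group = x['group'].iloc[0]
    if group % 2 == 0:
        x['new_val'] = 0
    else:
        x['new_val'] = group // 2 + 1
        # the last 2 rows of a cluster group are trailing zeros, not part of the cluster
        x.iloc[-2:, x.columns.get_loc('new_val')] = 0
    return x

# group by the values (not the column name) so that 'group' stays available inside the function;
# group_keys=False keeps the original index
df = df.groupby(df['group'].to_numpy(), group_keys=False).apply(adjust_group_value)
df = df.iloc[2:-2]  # remove the padded 0s at the start and end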

Result (all intermediate columns are kept to illustrate the process):

    val  sum  group  new_val
2     0  0.0    0.0      0.0
3     0  0.0    0.0      0.0
4     0  0.0    0.0      0.0
5     1  1.0    1.0      1.0
6     1  2.0    1.0      1.0
7     1  3.0    1.0      1.0
8     1  3.0    1.0      1.0
9     0  2.0    1.0      0.0
10    0  1.0    1.0      0.0
11    0  0.0    2.0      0.0
12    1  1.0    3.0      2.0
13    1  2.0    3.0      2.0
14    0  2.0    3.0      2.0
15    1  2.0    3.0      2.0
16    0  1.0    3.0      2.0
17    1  2.0    3.0      2.0
18    0  1.0    3.0      0.0
19    0  1.0    3.0      0.0
20    0  0.0    4.0      0.0
21    0  0.0    4.0      0.0
22    1  1.0    5.0      3.0
23    1  2.0    5.0      3.0
24    1  3.0    5.0      3.0

The final cluster ids can be obtained with df['new_val'].astype(int).values:

[0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 3, 3, 3]
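
The mapping from group numbers to cluster ids can also be sketched without a custom function. The version below is an illustrative variant, not the code from this answer: it assumes the same padding and window size, and it uses cumcount to find the 2 trailing rows of each odd group. The names `rows_from_end`, `is_cluster` and `cluster_id` are made up for this sketch.

```python
import pandas as pd

data = [0,0,0,1,1,1,1,0,0,0,1,1,0,1,0,1,0,0,0,0,1,1,1]
padded = pd.Series([0, 0] + data + [0, 0])

active = padded.rolling(window=3, min_periods=3).sum() > 0
group = active.astype(int).diff().abs().cumsum().fillna(0).astype(int)

# Position of each row counted from the end of its group (0 = last row of the group).
rows_from_end = group.groupby(group).cumcount(ascending=False)

# Odd groups are clusters, except for their last 2 rows (the trailing zeros).
is_cluster = (group % 2 == 1) & (rows_from_end >= 2)

# Odd group numbers 1, 3, 5, ... become cluster ids 1, 2, 3, ...; everything else is 0.
cluster_id = (group // 2 + 1).where(is_cluster, 0).iloc[2:-2]

print(cluster_id.tolist())
# [0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 3, 3, 3]
```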

CodePudding user response:

Another solution, not requiring the use of .apply():

import pandas as pd

# Store the initial list in a pandas Series
ser = pd.Series([0,0,0,1,1,1,1,0,0,0,1,1,0,1,0,1,0,0,0,0,1,1,1])

First, identify each consecutive run of 1's and 0's and label every element with the size of its run:

grp_ser = ser.groupby((ser.diff() != 0).cumsum()).transform('size')
print(grp_ser.to_list())
# [3, 3, 3, 4, 4, 4, 4, 3, 3, 3, 2, 2, 1, 1, 1, 1, 4, 4, 4, 4, 3, 3, 3]

Using a copy of the original series, set the value to 1 in every row where the original value is 0 and the size of its group is less than 3:

ser_copy = ser.copy()
ser_copy.loc[(grp_ser < 3) & ser.eq(0)] = 1
print(ser_copy.to_list())
# [0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1]
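
The gap-filling idea does not depend on pandas. As a rough cross-check, the same two steps (measure each run, then turn short runs of 0's into 1's) can be sketched with itertools.groupby from the standard library. The function name `fill_short_gaps` and the parameter `min_gap` are hypothetical names used only for this sketch.

```python
from itertools import groupby

def fill_short_gaps(values, min_gap=3):
    """Return a copy of `values` where every run of 0's shorter than `min_gap` is set to 1."""
    filled = []
    for value, run in groupby(values):
        size = len(list(run))  # length of this run of identical values
        if value == 0 and size < min_gap:
            value = 1  # the gap is too short to separate two clusters
        filled.extend([value] * size)
    return filled

data = [0,0,0,1,1,1,1,0,0,0,1,1,0,1,0,1,0,0,0,0,1,1,1]
print(fill_short_gaps(data))
# [0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1]
```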

At this point, all clusters have been identified and only need to be numbered consecutively.

Create a running sum that increments by 1 wherever a 0 turns into a 1:

res = ((ser_copy.diff() != 0) & (ser_copy != 0)).cumsum()
print(res.to_list())
# [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3]

Reset to 0 the positions where ser_copy is 0, since the running sum carried the preceding cluster id over them:

res[ser_copy == 0] = 0
print(res.to_list())
# [0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 3, 3, 3]
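
For reuse, the steps above can be wrapped in one function. This is a sketch under some assumptions: the name `label_clusters` and the parameter `min_gap` are invented here, and the `interior` check is an addition that is not in the steps above. It leaves a short run of 0's at the very start or end of the vector untouched, on the assumption that such a run does not separate two clusters and so should not be filled.

```python
import pandas as pd

def label_clusters(values, min_gap=3):
    """Label clusters of 1's, merging clusters separated by fewer than `min_gap` 0's."""
    if len(values) == 0:
        return []
    ser = pd.Series(values)

    # Number each run of identical values, then attach the run length to every element.
    run_id = (ser.diff() != 0).cumsum()
    run_size = ser.groupby(run_id).transform('size')

    # Assumption of this sketch: only runs strictly between the first and last run are gaps.
    interior = (run_id > run_id.iloc[0]) & (run_id < run_id.iloc[-1])

    filled = ser.copy()
    filled[(run_size < min_gap) & ser.eq(0) & interior] = 1

    # Increment the id at every 0 -> 1 transition, then clear the positions that are 0.
    ids = ((filled.diff() != 0) & (filled != 0)).cumsum()
    ids[filled == 0] = 0
    return ids.tolist()

data = [0,0,0,1,1,1,1,0,0,0,1,1,0,1,0,1,0,0,0,0,1,1,1]
print(label_clusters(data))
# [0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 3, 3, 3]
```

On the question's vector the `interior` check makes no difference, since the leading run of 0's already has length 3.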