Merge lists in a dataframe column if they share a common value

Time:09-23

What I need:

I have a dataframe where the elements of a column are lists. There are no duplications of elements in a list. For example, a dataframe like the following:

>>> import pandas as pd
>>> d = {'col1': [[1, 2, 4, 8], [15, 16, 17], [18, 3], [2, 19], [10, 4]]}
>>> df = pd.DataFrame(data=d)
>>> df
           col1
0  [1, 2, 4, 8]
1  [15, 16, 17]
2       [18, 3]
3       [2, 19]
4       [10, 4]

I would like to obtain a dataframe where, if at least one number in the list at row i also appears in the list at row j, the two lists are merged (without duplicates). A value may also be shared by more than two lists; in that case, I want all lists that share at least one value to be merged.

                   col1
0  [1, 2, 4, 8, 19, 10]
1          [15, 16, 17]
2               [18, 3]

Neither the order of the rows in the output dataframe nor the order of the values inside a list is important.


What I tried:

I have found this approach: merging lists with a networkx graph.

You can thus:

  • generate successive edges with add_edges_from
  • find the connected_components
  • craft a dictionary and map the first item of each list
  • groupby and merge the lists (you could use the connected components directly, but I'm giving a pandas solution in case you have more columns to handle)

import networkx as nx

G = nx.Graph()
for l in df['col1']:
    G.add_edges_from(zip(l, l[1:]))

groups = {k:v for v,l in enumerate(nx.connected_components(G)) for k in l}
# {1: 0, 2: 0, 4: 0, 8: 0, 10: 0, 19: 0, 16: 1, 17: 1, 15: 1, 18: 2, 3: 2}

out = (df.groupby(df['col1'].str[0].map(groups), as_index=False)
         .agg(lambda x: sorted(set().union(*x)))
       )

output:

                   col1
0  [1, 2, 4, 8, 10, 19]
1          [15, 16, 17]
2               [3, 18]
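
To make the graph idea above concrete without any dependency, here is a minimal standard-library sketch of the same steps: link consecutive items of each list, then collect the connected components with a breadth-first search. The function name connected_groups is hypothetical (it is not part of networkx or pandas), and this is only an illustration of the technique, not how networkx implements connected_components.

```python
from collections import defaultdict, deque

def connected_groups(lists):
    """Group values that are linked through shared lists (hypothetical helper)."""
    adj = defaultdict(set)
    for l in lists:
        for v in l:
            adj[v]  # register every value as a node, even in one-element lists
        # linking consecutive items is enough to connect the whole list
        for a, b in zip(l, l[1:]):
            adj[a].add(b)
            adj[b].add(a)

    seen, groups = set(), []
    for start in adj:
        if start in seen:
            continue
        # breadth-first search collects everything reachable from `start`
        seen.add(start)
        component, queue = {start}, deque([start])
        while queue:
            node = queue.popleft()
            for nxt in adj[node]:
                if nxt not in seen:
                    seen.add(nxt)
                    component.add(nxt)
                    queue.append(nxt)
        groups.append(sorted(component))
    return groups

lists = [[1, 2, 4, 8], [15, 16, 17], [18, 3], [2, 19], [10, 4]]
merged = connected_groups(lists)
print(merged)  # [[1, 2, 4, 8, 10, 19], [15, 16, 17], [3, 18]]
```

One difference worth noting: this sketch registers every value as a node, so a one-element list survives as its own group, whereas zip(l, l[1:]) alone produces no edge for such a list.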

CodePudding user response:

This seems more like a Python problem than a pandas one, so here's one attempt that, for each list, checks every list after it and merges (and removes) the ones that intersect:

vals = d["col1"]

# while there are lists left to process
# (the last one is included so that it also ends up as a set)...
i = 0
while i < len(vals):
    current = set(vals[i])

    # for the lists after it...
    j = i + 1
    while j < len(vals):
        other = vals[j]
        if current.intersection(other):
            # shared value: merge into current and delete the other list
            current.update(other)
            del vals[j]
            # rescan from i + 1: a list skipped earlier may intersect the grown set
            j = i + 1
        else:
            # no intersection, so keep going with the next list
            j += 1

    # put the updated current back, and move on
    vals[i] = current
    i += 1

At the end, vals is:

In [108]: vals
Out[108]: [{1, 2, 4, 8, 10, 19}, {15, 16, 17}, {3, 18}]

In [109]: pd.Series(map(list, vals))
Out[109]:
0    [1, 2, 19, 4, 8, 10]
1            [16, 17, 15]
2                 [18, 3]
dtype: object

If you don't want d["col1"] modified in place, chain .copy() onto it: vals = d["col1"].copy().
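
As a sketch of the non-mutating variant, the same loop can be wrapped in a function that works on fresh sets, so the caller's data is left alone. The name merge_overlapping is hypothetical. This version also restarts the inner scan after every merge, on the assumption that a list skipped earlier may intersect the set once it has grown.

```python
def merge_overlapping(lists):
    """Return a list of merged sets; `lists` and its inner lists stay untouched."""
    vals = [set(l) for l in lists]  # fresh sets, so nothing in `lists` is modified
    i = 0
    while i < len(vals):
        j = i + 1
        while j < len(vals):
            if vals[i] & vals[j]:
                # shared value: absorb the other set and remove it from the list
                vals[i].update(vals.pop(j))
                # rescan, since the grown set may now overlap earlier-skipped ones
                j = i + 1
            else:
                j += 1
        i += 1
    return vals

d = {'col1': [[1, 2, 4, 8], [15, 16, 17], [18, 3], [2, 19], [10, 4]]}
result = merge_overlapping(d['col1'])
print(result)
print(d['col1'])  # still the original five lists
```

The result can then be passed to pd.Series(map(list, result)) as above.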
