I have a dataframe (df) similar to this:
Node_Start | Node_End |
---|---|
1.0 | 208.0 |
1.0 | 911.0 |
800.0 | 1.0 |
3.0 | 800.0 |
2.0 | 511.0 |
700.0 | 3.0 |
200.0 | 4.0 |
I would like to add a column that assigns each row to a cluster of related rows, based on the values in the columns 'Node_Start' and 'Node_End':
Node_Start | Node_End | Group |
---|---|---|
1.0 | 208.0 | 1 |
1.0 | 911.0 | 1 |
800.0 | 1.0 | 1 |
3.0 | 800.0 | 1 |
2.0 | 511.0 | 2 |
700.0 | 3.0 | 1 |
200.0 | 4.0 | 3 |
In other words, since 1.0 appears in both 'Node_Start' and 'Node_End', the rows containing it are assigned to Group 1. Since 800.0 is connected to both 1.0 and 3.0, the rows containing 800.0 and 3.0 are also assigned to Group 1. The values 2.0 and 511.0 are not related to any other row values, so they are assigned to Group 2. Likewise, 200.0 and 4.0 are not related to any other rows and are assigned to Group 3. And so on...
The following code produces the desired result, but it is a bit clunky and does not work on my entire dataset: the dataset is too big (over 500,000 rows) and my kernel crashes before the job completes.
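def consolidate(sets):
    # http://rosettacode.org/wiki/Set_consolidation#Python:_Iterative
    # Merge sets that share an element until the remaining sets are disjoint.
    # Each non-empty set is compared with every later set (quadratic in the number of pairs).
    setlist = [s for s in sets if s]
    for i, s1 in enumerate(setlist):
        if s1:
            for s2 in setlist[i + 1:]:
                intersection = s1.intersection(s2)
                if intersection:
                    # Fold s1 into s2, empty s1, and keep merging from s2.
                    s2.update(s1)
                    s1.clear()
                    s1 = s2
    return [s for s in setlist if s]

def group_ids(pairs):
    # One set per (start, end) pair; merge overlapping sets,
    # then map every node to the index of its merged group.
    groups = consolidate(map(set, pairs))
    d = {}
    for i, group in enumerate(sorted(groups)):
        for elem in group:
            d[elem] = i
    return d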
This looks like a directed graph. Python has a nice module for working with graphs, NetworkX, and your problem amounts to finding connected components.
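To make "connected components" concrete: treat each row as an undirected link between its two values; a connected component is a maximal set of values that can all reach one another through such links. The following pure-Python sketch finds them with a breadth-first search and numbers them in order of first appearance, which reproduces the Group column from the question. The function name connected_component_labels is hypothetical and is not part of NetworkX.

```python
from collections import defaultdict, deque


def connected_component_labels(rows):
    """Label each (start, end) row with a 1-based component number.

    Direction is ignored; components are numbered in the order their
    first start node is met while scanning the rows.
    """
    # Undirected adjacency: every row links its two endpoints both ways.
    adjacency = defaultdict(set)
    for start, end in rows:
        adjacency[start].add(end)
        adjacency[end].add(start)

    component_of = {}
    next_label = 1
    for start, _ in rows:
        if start in component_of:
            continue
        # Breadth-first search from an unlabelled node marks everything reachable.
        component_of[start] = next_label
        queue = deque([start])
        while queue:
            node = queue.popleft()
            for neighbour in adjacency[node]:
                if neighbour not in component_of:
                    component_of[neighbour] = next_label
                    queue.append(neighbour)
        next_label += 1

    # Both endpoints of a row share a component, so the start node suffices.
    return [component_of[start] for start, _ in rows]


rows = [(1.0, 208.0), (1.0, 911.0), (800.0, 1.0), (3.0, 800.0),
        (2.0, 511.0), (700.0, 3.0), (200.0, 4.0)]
print(connected_component_labels(rows))  # → [1, 1, 1, 1, 2, 1, 3]
```

Each node and each row is visited a constant number of times, which is why a components search scales far better than comparing every pair of sets.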
So we could first build a graph whose nodes are the values in df and whose edges are the rows (for the purposes of this problem directedness is immaterial, so we drop that attribute here). Then create a mapping from each node to its component number using a dict comprehension, and map it onto one of the columns. Either column works, since both endpoints of a row always fall in the same component:
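import networkx as nx

# Each (Node_Start, Node_End) row becomes an undirected edge;
# selecting the two columns keeps any other columns out of the edge list.
arr = df[['Node_Start', 'Node_End']].to_numpy()
G = nx.Graph()
G.add_edges_from(arr)
# Number the components from 1 and record each node's component.
mapping = {node: i for i, component in enumerate(nx.connected_components(G), 1) for node in component}
df['Group'] = df['Node_Start'].map(mapping)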
Output:
Node_Start Node_End Group
0 1.0 208.0 1
1 1.0 911.0 1
2 800.0 1.0 1
3 3.0 800.0 1
4 2.0 511.0 2
5 700.0 3.0 1
6 200.0 4.0 3
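If NetworkX is not available, a disjoint-set (union-find) structure is a common alternative for the same connected-components problem. This is a different technique from the answer above, shown only as a sketch: union_find_labels is a hypothetical helper, it uses path halving in find, and it keeps a single parent pointer per node instead of a full adjacency structure. Groups are numbered by first appearance over the rows, matching the example table.

```python
def union_find_labels(starts, ends):
    """Return a 1-based component label per row using a disjoint-set forest."""
    parent = {}

    def find(node):
        # Register unseen nodes as their own root.
        parent.setdefault(node, node)
        # Path halving: point each visited node at its grandparent.
        while parent[node] != node:
            parent[node] = parent[parent[node]]
            node = parent[node]
        return node

    def union(a, b):
        root_a, root_b = find(a), find(b)
        if root_a != root_b:
            parent[root_b] = root_a

    # Every row links its two endpoints.
    for a, b in zip(starts, ends):
        union(a, b)

    # Renumber roots 1, 2, 3, ... in order of first appearance over the rows.
    label_of_root = {}
    labels = []
    for a in starts:
        root = find(a)
        if root not in label_of_root:
            label_of_root[root] = len(label_of_root) + 1
        labels.append(label_of_root[root])
    return labels


starts = [1.0, 1.0, 800.0, 3.0, 2.0, 700.0, 200.0]
ends = [208.0, 911.0, 1.0, 800.0, 511.0, 3.0, 4.0]
print(union_find_labels(starts, ends))  # → [1, 1, 1, 1, 2, 1, 3]
```

With a DataFrame, the result can be attached directly, e.g. df['Group'] = union_find_labels(df['Node_Start'].tolist(), df['Node_End'].tolist()).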