Create graph with curved and labelled edges in Networkx

Time:05-11

I'm currently drawing small directed graphs (a few nodes and the edges connecting them) with nx.draw, and labelling the edges with nx.draw_networkx_edge_labels.

Now I want to make the graph look less rigid by curving the edges with the connectionstyle argument, which works fine for unlabelled edges.

The problem is that when I display the labels, they are drawn as if the edges were straight, so each label ends up far away from its curved edge.

Is there any way to work around this limitation? I could not find an "offset" option in nx.draw_networkx_edge_labels that would address it.
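For reference, the size of the offset can be worked out. nx.draw_networkx_edge_labels places each label on the straight segment between the two node positions (at label_pos, 0.5 by default), while connectionstyle="arc3, rad=r" makes matplotlib draw each edge as a quadratic Bézier curve. Going by my reading of matplotlib's arc3 connection style, and assuming an equal-aspect axes (matplotlib evaluates the curve in display coordinates), the control point C and the curve's midpoint B(1/2) for an edge from P_0 = (x_1, y_1) to P_2 = (x_2, y_2) are:

```latex
C = \frac{P_0 + P_2}{2} + r \begin{pmatrix} y_2 - y_1 \\ x_1 - x_2 \end{pmatrix},
\qquad
B\!\left(\tfrac{1}{2}\right) = \tfrac{1}{4} P_0 + \tfrac{1}{2} C + \tfrac{1}{4} P_2
  = \frac{P_0 + P_2}{2} + \frac{r}{2} \begin{pmatrix} y_2 - y_1 \\ x_1 - x_2 \end{pmatrix}
```

So the midpoint of the curve sits |r|/2 times the edge length away from the straight-line midpoint where the label is drawn; with rad=0.2 that is a tenth of the edge length.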

EDIT:

Below is a quick example of the issue:

import matplotlib.pyplot as plt
import networkx as nx

tab = ("r", ["s", "t", "u", "v", "w", "x", "y", "z"])

producer = tab[0]
consumers = tab[1]

DG = nx.DiGraph()
for i, cons in enumerate(consumers):
    DG.add_edge(producer, cons, label=f"edge-{i}")

# Color the producer yellow and every consumer blue
color_map = ["#DCE46F" if node == producer else "#6FA2E4" for node in DG.nodes()]
pos = nx.shell_layout(DG)
labels = nx.get_edge_attributes(DG, 'label')
nx.draw(DG, pos, node_color=color_map, connectionstyle="arc3, rad=0.2", with_labels=True, font_size=8, node_size=1000, node_shape='o')
nx.draw_networkx_edge_labels(DG, pos, edge_labels=labels)

plt.show()

Current output: [image]

Answer:

If you are open to using other libraries for the visualization, I wrote (and maintain) netgraph, which draws curved edges with their labels placed along the curves:

[image]

import numpy as np
import matplotlib.pyplot as plt
import networkx as nx

from netgraph import Graph # pip install netgraph

tab = ("r", ["s", "t", "u", "v", "w", "x", "y", "z"])

producer = tab[0]
consumers = tab[1]

DG = nx.DiGraph()
for i, cons in enumerate(consumers):
    DG.add_edge(producer, cons, label=f"edge-{i}")

node_color = dict()
for node in DG:
    if node == producer:
        node_color[node] = "#DCE46F"
    else:
        node_color[node] = "#6FA2E4"

pos = nx.shell_layout(DG)
pos[producer] = pos[producer] + np.array([0.2, 0])
edge_labels = nx.get_edge_attributes(DG, 'label')

Graph(DG, node_layout=pos, edge_layout='curved', origin=(-1, -1), scale=(2, 2),
      node_color=node_color, node_size=8.,
      node_labels=True, node_label_fontdict=dict(size=10),
      edge_labels=edge_labels, edge_label_fontdict=dict(size=10),
)
plt.show()
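If you would rather stay with networkx and matplotlib, one workaround is to skip nx.draw_networkx_edge_labels and place each label yourself with ax.text at the midpoint of the arc3 curve. This is a minimal sketch, assuming an equal-aspect axes and the same rad for the edges and the labels; arc3_midpoint, curved_edge_label_positions and RAD are hypothetical names introduced here, not part of networkx or matplotlib.

```python
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np

RAD = 0.2  # hypothetical constant: must match the rad used in connectionstyle


def arc3_midpoint(p0, p2, rad):
    """Point at t = 0.5 on matplotlib's arc3 quadratic Bezier from p0 to p2.

    Assumes an equal-aspect axes, because matplotlib evaluates arc3 in
    display coordinates rather than data coordinates.
    """
    p0 = np.asarray(p0, dtype=float)
    p2 = np.asarray(p2, dtype=float)
    dx, dy = p2 - p0
    # Straight-line midpoint shifted by half of the control-point offset
    return (p0 + p2) / 2 + 0.5 * rad * np.array([dy, -dx])


def curved_edge_label_positions(G, pos, rad):
    """Map each edge (u, v) to the midpoint of its curved drawing."""
    return {(u, v): arc3_midpoint(pos[u], pos[v], rad) for u, v in G.edges()}


producer = "r"
consumers = ["s", "t", "u", "v", "w", "x", "y", "z"]

DG = nx.DiGraph()
for i, cons in enumerate(consumers):
    DG.add_edge(producer, cons, label=f"edge-{i}")

pos = nx.shell_layout(DG)
labels = nx.get_edge_attributes(DG, "label")
node_color = ["#DCE46F" if n == producer else "#6FA2E4" for n in DG.nodes()]

fig, ax = plt.subplots(figsize=(6, 6))
ax.set_aspect("equal")  # keep data and display geometry proportional
nx.draw(DG, pos, ax=ax, node_color=node_color,
        connectionstyle=f"arc3, rad={RAD}",
        with_labels=True, font_size=8, node_size=1000)

# Draw each edge label at the midpoint of its curved edge
label_pos = curved_edge_label_positions(DG, pos, RAD)
for edge, (x, y) in label_pos.items():
    ax.text(x, y, labels[edge], fontsize=8, ha="center", va="center",
            bbox=dict(boxstyle="round", fc="white", ec="none"))

plt.show()
```

Because the label offset is proportional to rad, changing the curvature only requires changing RAD; a negative value bends both the edges and the labels to the other side.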