I'm currently displaying directed graphs with a few nodes and edges connecting them, using nx.draw. The edges are labelled via nx.draw_networkx_edge_labels.
Now I wanted to "lighten" the "rigidity" of the graph by setting the connectionstyle, which works fine with unlabelled edges.
The problem is that if I display labels, they are drawn as if the edge were not curved, which creates a large offset between each edge and its label.
Is there any way to work around this limitation? I could not find an "offset" option in nx.draw_networkx_edge_labels to address this issue.
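The size of the offset can be worked out from how matplotlib builds the curve. The derivation below assumes the `arc3` style behaves like matplotlib's `ConnectionStyle.Arc3`: a quadratic Bezier curve whose control point is the straight midpoint pushed sideways by `rad` times the edge vector rotated by 90 degrees. Under that assumption, the curve's own midpoint sits half that distance away from the straight midpoint, which is where the label gets drawn:

```latex
% arc3 with rad = r between P_1 = (x_1, y_1) and P_2 = (x_2, y_2)
\begin{aligned}
M &= \tfrac{1}{2}(P_1 + P_2), \qquad \Delta x = x_2 - x_1, \qquad \Delta y = y_2 - y_1 \\
C &= M + r\,(\Delta y,\; -\Delta x) \\
B(t) &= (1-t)^2 P_1 + 2t(1-t)\,C + t^2 P_2 \\
B(\tfrac{1}{2}) &= \tfrac{1}{4}P_1 + \tfrac{1}{2}C + \tfrac{1}{4}P_2 = M + \tfrac{r}{2}\,(\Delta y,\; -\Delta x)
\end{aligned}
```

So with `rad=0.2` a label placed at the straight midpoint misses the curve by 0.1 times the edge length. Matplotlib performs this construction in display coordinates, so the same formula holds in data coordinates only when the axes have an equal aspect ratio.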
EDIT:
Below is a quick example of the issue:
import matplotlib.pyplot as plt
import networkx as nx
tab = ("r", ["s", "t", "u", "v", "w", "x", "y", "z"])
producer = tab[0]
consumers = tab[1]
DG = nx.DiGraph()
for i, cons in enumerate(consumers):
    DG.add_edge(producer, cons, label=f"edge-{i}")
# one color per node, in DG.nodes() order: producer yellow, consumers blue
color_map = []
for node in DG.nodes():
    if node == producer:
        color_map.append("#DCE46F")
    else:
        color_map.append("#6FA2E4")
pos = nx.shell_layout(DG)
labels = nx.get_edge_attributes(DG, 'label')
# connectionstyle curves the edges (directed graphs are drawn with FancyArrowPatch)
nx.draw(DG, pos, node_color=color_map, connectionstyle="arc3, rad=0.2", with_labels=True, font_size=8, node_size=1000, node_shape='o')
# the labels are still placed on the straight line between the two nodes
nx.draw_networkx_edge_labels(DG, pos, edge_labels=labels)
plt.show()
Current output: [image not available]
Answer:
If you are open to using other libraries for the visualization, I wrote (and maintain) netgraph, which can draw curved edges with their labels placed on the curve:
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx
from netgraph import Graph # pip install netgraph
tab = ("r", ["s", "t", "u", "v", "w", "x", "y", "z"])
producer = tab[0]
consumers = tab[1]
DG = nx.DiGraph()
for i, cons in enumerate(consumers):
    DG.add_edge(producer, cons, label=f"edge-{i}")
# netgraph takes node colors as a dict keyed by node
node_color = dict()
for node in DG:
    if node == producer:
        node_color[node] = "#DCE46F"
    else:
        node_color[node] = "#6FA2E4"
pos = nx.shell_layout(DG)
# shift the producer node by 0.2 along the x-axis
pos[producer] = pos[producer] + np.array([0.2, 0])
edge_labels = nx.get_edge_attributes(DG, 'label')
# origin and scale define the canvas: lower-left corner (-1, -1), width and height 2
Graph(DG, node_layout=pos, edge_layout='curved', origin=(-1, -1), scale=(2, 2),
      node_color=node_color, node_size=8.,
      node_labels=True, node_label_fontdict=dict(size=10),
      edge_labels=edge_labels, edge_label_fontdict=dict(size=10),
)
plt.show()
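If you would rather stay with networkx, one possible workaround (a sketch, not a documented networkx feature) is to skip nx.draw_networkx_edge_labels and place each label yourself with matplotlib's `ax.text` at the midpoint of the `arc3` curve, i.e. the straight midpoint shifted by `rad/2` times the edge vector rotated by 90 degrees. The helper `draw_curved_edge_labels` is hypothetical; it assumes one `rad` shared by all edges and an equal-aspect axes, so that the data-space midpoint matches the curve matplotlib draws in display space.

```python
import matplotlib.pyplot as plt
import networkx as nx


def draw_curved_edge_labels(G, pos, edge_labels, rad, ax):
    """Hypothetical helper: put each label on the arc3 curve drawn for
    connectionstyle=f"arc3,rad={rad}".

    Assumes ax has an equal aspect ratio, because matplotlib builds the
    arc3 curve in display coordinates.
    """
    texts = {}
    for (u, v), label in edge_labels.items():
        (x1, y1), (x2, y2) = pos[u], pos[v]
        dx, dy = x2 - x1, y2 - y1
        # Quadratic Bezier at t=0.5: straight midpoint shifted by
        # rad/2 along the perpendicular direction (dy, -dx).
        x = (x1 + x2) / 2 + 0.5 * rad * dy
        y = (y1 + y2) / 2 - 0.5 * rad * dx
        texts[(u, v)] = ax.text(
            x, y, label, size=8, ha="center", va="center",
            bbox=dict(boxstyle="round", ec="none", fc="white"),
        )
    return texts


producer, consumers = "r", ["s", "t", "u", "v", "w", "x", "y", "z"]
DG = nx.DiGraph()
for i, cons in enumerate(consumers):
    DG.add_edge(producer, cons, label=f"edge-{i}")

rad = 0.2
pos = nx.shell_layout(DG)
labels = nx.get_edge_attributes(DG, "label")

fig, ax = plt.subplots()
ax.set_aspect("equal")  # keeps the data-space curve identical to the drawn one
nx.draw(DG, pos, ax=ax, connectionstyle=f"arc3,rad={rad}",
        with_labels=True, font_size=8, node_size=1000)
texts = draw_curved_edge_labels(DG, pos, labels, rad, ax)
plt.show()
```

Because the aspect ratio is fixed, resizing the window should keep the labels on their curves. The labels are not rotated to follow the edge direction; that would need an extra angle computation.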