I am trying to animate a graph whose edge widths and colors change over time. My code works, but it is extremely slow. I imagine there are more efficient implementations.
def minimal_graph(datas, pos):
    """Animate one graph panel per dataset and save the result as a video.

    For every frame, a weighted graph is rebuilt from that frame's
    adjacency matrix and drawn with both edge colour and edge width taken
    from the edge weights (colour scale clamped to [-5, 5]).

    Args:
        datas: Sequence of datasets. Each dataset is a sequence of
            adjacency matrices, one per frame. The number of frames is
            taken from the first dataset.
        pos: Node positions in the form accepted by ``nx.draw`` (for
            example the dict returned by ``nx.spring_layout``).

    Returns:
        None. The animation is written to ``temp/graphs.mp4``; the
        ``temp`` directory is created if it does not exist yet.
    """
    import os  # local import: only needed to prepare the output directory

    output_path = 'temp/graphs.mp4'
    frames = len(datas[0])

    # ``from_numpy_matrix`` was removed in NetworkX 3.0. Prefer its
    # replacement and fall back only on releases that do not provide it.
    graph_from_array = getattr(nx, 'from_numpy_array', None)
    if graph_from_array is None:
        graph_from_array = nx.from_numpy_matrix

    # At least two panels (the historical layout), more when more datasets
    # are supplied, so a third dataset no longer indexes past the axes.
    fig, axes = plt.subplots(1, max(len(datas), 2))
    axes = axes.flatten()

    # No separate initial draw is done here: every frame, including the
    # first one FuncAnimation renders before saving, clears each axis and
    # redraws it, so anything drawn beforehand would be thrown away.

    def update(it, data, pos, ax):
        """Redraw every panel for frame index ``it``."""
        print(it)
        for i, dat in enumerate(data):
            # Clearing the axis means everything is drawn from scratch on
            # every frame; this is the expensive part of the animation.
            ax[i].clear()
            G = graph_from_array(dat[it])
            weight_by_edge = nx.get_edge_attributes(G, 'weight')
            draw_kwargs = {'node_color': '#24FF00', 'ax': ax[i]}
            if weight_by_edge:
                edges, weights = zip(*weight_by_edge.items())
                draw_kwargs.update(
                    edgelist=edges,
                    edge_color=weights,
                    width=weights,
                    edge_vmin=-5,
                    edge_vmax=5,
                )
            # With no edges, unpacking ``zip()`` would raise ValueError;
            # in that case only the nodes are drawn.
            nx.draw(G, pos, **draw_kwargs)

    ani = animation.FuncAnimation(
        fig,
        update,
        frames=frames,
        fargs=(datas, pos, axes),
        interval=100,
    )
    try:
        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        ani.save(output_path)
    finally:
        # Close this specific figure even when saving fails, so a failed
        # run does not leave the figure open.
        plt.close(fig)
# Build two series of 100 random 400x400 weight matrices. The matrices of
# both series are drawn alternately, one pair per frame.
frame_pairs = [
    (np.random.rand(400, 400), np.random.rand(400, 400))
    for _ in range(100)
]
dataset1 = [first for first, _ in frame_pairs]
dataset2 = [second for _, second in frame_pairs]
datasets = [dataset1, dataset2]

# Compute the node layout once, from the first frame of the first series,
# and reuse it for every frame of the animation.
G = nx.from_numpy_matrix(dataset1[0])
pos = nx.spring_layout(G)
minimal_graph(datasets, pos)
As pointed out in the code, the problem is that at every frame I redraw the graph from scratch. When using animations in matplotlib, I usually try to create lines and use the function `line.set_data()`, which I know is a lot faster. I just don't know how to do that for a graph using networkx. I found