Home > front end >  clickable netgraph in PyQt application with tabs
clickable netgraph in PyQt application with tabs

Time:05-12

I have a PyQt application which has a tab widget that can open any number of tabs. Each tab embeds a matplotlib canvas displaying graphs.

Lately, I have tried to implement InteractiveGraph from netgraph library, with little success despite the help of other stackoverflow similar topic. Maybe it comes from the additional presence of tabs, I don't know.

What I observe is that I can't manage to click on the graph nodes. Graph are properly displayed though.

Below is a quick example of my code (tabs have static graph values for testing, each added tab add an embedded graph), and how I tried to implement proposed solution of similar topic. I was not sure about necessity of using mpl_connect, so I tried with and without it and it didn't change anything.

import sys
from PyQt5 import QtCore, QtWidgets
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
from matplotlib.pyplot import Figure
import networkx as nx
import numpy as np
from netgraph import InteractiveGraph


class data_tab(QtWidgets.QWidget):

    def __init__(self, parent, title):

        QtWidgets.QWidget.__init__(self, parent)

        self.data_tab_glayout = QtWidgets.QGridLayout(self)
        self.canvas = FigureCanvas(Figure(figsize=(5, 3)))
        self.canvas.setParent(parent)

        self.canvas.setFocusPolicy(QtCore.Qt.ClickFocus)
        self.canvas.setFocus()

        self.canvas_vlayout = QtWidgets.QVBoxLayout(self.canvas)
        self.data_tab_glayout.addWidget(self.canvas, 0, 0, 2, 1)

        self.canvas.mpl_connect('key_press_event', self.on_key_press)

        self.axe = self.canvas.figure.add_subplot(111)
        self.canvas.figure.subplots_adjust(left=0.025, top=0.965, bottom=0.040, right=0.975)
        # add the tab to the parent
        parent.addTab(self, "")

        # set text name
        parent.setTabText(parent.indexOf(self), title)


    def on_key_press(self, event):
        print("you press", event.key)


class MyApp(QtWidgets.QMainWindow):

    def __init__(self, parent=None):
        QtWidgets.QMainWindow.__init__(self, parent)

        self.showMaximized()

        self.centralwidget = QtWidgets.QWidget(self)
        self.verticalLayout = QtWidgets.QVBoxLayout(self.centralwidget)
        self.core_tab = QtWidgets.QTabWidget(self.centralwidget)
        self.verticalLayout.addWidget(self.core_tab)
        self.add_tab_btn = QtWidgets.QPushButton(self.centralwidget)
        self.verticalLayout.addWidget(self.add_tab_btn)
        self.refresh_tab_btn = QtWidgets.QPushButton(self.centralwidget)
        self.verticalLayout.addWidget(self.refresh_tab_btn)
        self.setCentralWidget(self.centralwidget)

        self.add_tab_btn.setText("Add Tab")
        self.refresh_tab_btn.setText("Refresh Tabs")

        self.core_tab.setEnabled(True)
        self.core_tab.setTabShape(QtWidgets.QTabWidget.Rounded)
        self.core_tab.setElideMode(QtCore.Qt.ElideNone)
        self.core_tab.setDocumentMode(False)
        self.core_tab.setTabsClosable(True)
        self.core_tab.setMovable(True)
        self.core_tab.setTabBarAutoHide(False)

        self.tab_counter = 0

        self.random_tabs = [("a", ["b", "c"]),
                            ("d", ["e", "f", "g"]),
                            ("h", ["i", "j", "k", "l"]),
                            ("m", ["n"]),
                            ("o", ["p", "q"]),
                            ("r", ["s", "t", "u", "v", "w", "x", "y", "z"])]

        self.add_tab_btn.clicked.connect(self.openRandomTab)
        self.refresh_tab_btn.clicked.connect(self.refreshAllTabs)

    def openRandomTab(self):

        tab = data_tab(self.core_tab, "test "   str(self.tab_counter))
        self._drawDataGraph(self.tab_counter % len(self.random_tabs), tab)
        self.tab_counter  = 1

        self.core_tab.setCurrentIndex(self.core_tab.indexOf(tab))


    def _drawDataGraph(self, tabNb, dataWidget):
        dataWidget.axe.cla()

        # 1. draw graph
        producer = self.random_tabs[tabNb][0]
        consumers = self.random_tabs[tabNb][1]

        color_map = []
        DG = nx.DiGraph()
        for i, cons in enumerate(consumers):
            DG.add_edge(producer, cons, label=f"edge-{i}")

        node_color = dict()
        for node in DG:
            if node in producer:
                node_color[node] = "#DCE46F"
            else:
                node_color[node] = "#6FA2E4"
        pos = nx.shell_layout(DG)
        pos[producer] = pos[producer]   np.array([0.2, 0])
        labels = nx.get_edge_attributes(DG, 'label')

        graph_instance = InteractiveGraph(DG, node_layout=pos, edge_layout='curved', origin=(-1, -1), scale=(2, 2),
                                          node_color=node_color, node_size=8.,
                                          node_labels=True, node_label_fontdict=dict(size=10),
                                          edge_labels=labels, edge_label_fontdict=dict(size=10), ax=dataWidget.axe
                                          )

        dataWidget.canvas.draw()


    def refreshAllTabs(self):

        # loop through all pages and associated to get
        for tab_index in range(self.core_tab.count()):
            data_tab_widget = self.core_tab.widget(tab_index)

            # draw graph
            self._drawDataGraph(tab_index % len(self.random_tabs), data_tab_widget)




sys.argv = ['']
app = QtWidgets.QApplication(sys.argv)
main_app = MyApp()
main_app.show()
app.exec_()

CodePudding user response:

If you what click on the figure, I think you need to connect mpl_connect on the figure, not on the canvas

self.figure = Figure(facecolor='white')
self.canvas = FigureCanvas(self.figure)

then make the connection with the 'button_press_event'

self.figure.canvas.mpl_connect('button_press_event', self.onclick)

If you want get the node that was clicked on, you can put DG and graph_instance in the data_tab class

    dataWidget.DG = DG
    dataWidget.graph_instance =graph_instance

then use them in the callback function

def onclick(self,event):
    print('%s click: button=%d, x=%d, y=%d, xdata=%f, ydata=%f' %
          ('double' if event.dblclick else 'single', event.button,
           event.x, event.y, event.xdata, event.ydata))
    x = event.xdata
    y = event.ydata
    for n in self.graph_instance.node_artists :
        node = self.graph_instance.node_artists [n]
        dist = ((x-node.xy[0])**2   (y-node.xy[1])**2)**0.5
        if dist < node.radius:
            print(node)

the function could be something like

import sys
from PyQt5 import QtCore, QtWidgets
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
from matplotlib.pyplot import Figure
import networkx as nx
import numpy as np
from netgraph import InteractiveGraph


class data_tab(QtWidgets.QWidget):

    def __init__(self, parent, title):

        QtWidgets.QWidget.__init__(self, parent)
        self.parent = parent
        self.data_tab_glayout = QtWidgets.QGridLayout(self)
        self.figure = Figure(facecolor='white')
        self.canvas = FigureCanvas(self.figure)
        self.canvas.setParent(parent)

        self.canvas.setFocusPolicy(QtCore.Qt.ClickFocus)
        self.canvas.setFocus()

        self.canvas_vlayout = QtWidgets.QVBoxLayout(self.canvas)
        self.data_tab_glayout.addWidget(self.canvas, 0, 0, 2, 1)

        # self.canvas.mpl_connect('key_press_event', self.on_key_press)
        self.figure.canvas.mpl_connect('button_press_event', self.onclick)
        self.axe = self.canvas.figure.add_subplot(111)
        self.canvas.figure.subplots_adjust(left=0.025, top=0.965, bottom=0.040, right=0.975)
        # add the tab to the parent
        parent.addTab(self, "")

        # set text name
        parent.setTabText(parent.indexOf(self), title)

    def onclick(self,event):
        print('%s click: button=%d, x=%d, y=%d, xdata=%f, ydata=%f' %
              ('double' if event.dblclick else 'single', event.button,
               event.x, event.y, event.xdata, event.ydata))
        x = event.xdata
        y = event.ydata
        for n in self.graph_instance.node_artists :
            node = self.graph_instance.node_artists [n]
            dist = ((x-node.xy[0])**2   (y-node.xy[1])**2)**0.5
            if dist < node.radius:
                print(node)



    def on_key_press(self, event):
        print("you press", event.key)


class MyApp(QtWidgets.QMainWindow):

    def __init__(self, parent=None):
        QtWidgets.QMainWindow.__init__(self, parent)

        self.showMaximized()

        self.centralwidget = QtWidgets.QWidget(self)
        self.verticalLayout = QtWidgets.QVBoxLayout(self.centralwidget)
        self.core_tab = QtWidgets.QTabWidget(self.centralwidget)
        self.verticalLayout.addWidget(self.core_tab)
        self.add_tab_btn = QtWidgets.QPushButton(self.centralwidget)
        self.verticalLayout.addWidget(self.add_tab_btn)
        self.refresh_tab_btn = QtWidgets.QPushButton(self.centralwidget)
        self.verticalLayout.addWidget(self.refresh_tab_btn)
        self.setCentralWidget(self.centralwidget)

        self.add_tab_btn.setText("Add Tab")
        self.refresh_tab_btn.setText("Refresh Tabs")

        self.core_tab.setEnabled(True)
        self.core_tab.setTabShape(QtWidgets.QTabWidget.Rounded)
        self.core_tab.setElideMode(QtCore.Qt.ElideNone)
        self.core_tab.setDocumentMode(False)
        self.core_tab.setTabsClosable(True)
        self.core_tab.setMovable(True)
        self.core_tab.setTabBarAutoHide(False)

        self.tab_counter = 0

        self.random_tabs = [("a", ["b", "c"]),
                            ("d", ["e", "f", "g"]),
                            ("h", ["i", "j", "k", "l"]),
                            ("m", ["n"]),
                            ("o", ["p", "q"]),
                            ("r", ["s", "t", "u", "v", "w", "x", "y", "z"])]

        self.add_tab_btn.clicked.connect(self.openRandomTab)
        self.refresh_tab_btn.clicked.connect(self.refreshAllTabs)

    def openRandomTab(self):

        tab = data_tab(self.core_tab, "test "   str(self.tab_counter))
        self._drawDataGraph(self.tab_counter % len(self.random_tabs), tab)
        self.tab_counter  = 1

        self.core_tab.setCurrentIndex(self.core_tab.indexOf(tab))


    def _drawDataGraph(self, tabNb, dataWidget):
        dataWidget.axe.cla()

        # 1. draw graph
        producer = self.random_tabs[tabNb][0]
        consumers = self.random_tabs[tabNb][1]

        color_map = []
        DG = nx.DiGraph()
        for i, cons in enumerate(consumers):
            DG.add_edge(producer, cons, label=f"edge-{i}")

        node_color = dict()
        for node in  DG:
            if node in producer:
                node_color[node] = "#DCE46F"
            else:
                node_color[node] = "#6FA2E4"
        pos = nx.shell_layout( DG)
        pos[producer] = pos[producer]   np.array([0.2, 0])
        labels = nx.get_edge_attributes( DG, 'label')

        graph_instance = InteractiveGraph( DG, node_layout=pos, edge_layout='curved', origin=(-1, -1), scale=(2, 2),
                                          node_color=node_color, node_size=8.,
                                          node_labels=True, node_label_fontdict=dict(size=10),
                                          edge_labels=labels, edge_label_fontdict=dict(size=10), ax=dataWidget.axe,pickable=True
                                          )
        dataWidget.DG = DG
        dataWidget.graph_instance =graph_instance
        # dataWidget.canvas.show(pickable=True )
        dataWidget.canvas.draw()


    def refreshAllTabs(self):

        # loop through all pages and associated to get
        for tab_index in range(self.core_tab.count()):
            data_tab_widget = self.core_tab.widget(tab_index)

            # draw graph
            self._drawDataGraph(tab_index % len(self.random_tabs), data_tab_widget)




sys.argv = ['']
app = QtWidgets.QApplication(sys.argv)
main_app = MyApp()
main_app.show()
app.exec_()

CodePudding user response:

OK, following ymmx answer, I figured out that the instanciation of InteractiveGraph needs to not be local.

So here's the solution which seems to behave as wanted in my case:

import sys
from PyQt5 import QtCore, QtWidgets
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
from matplotlib.pyplot import Figure
import networkx as nx
import numpy as np
from netgraph import InteractiveGraph


class data_tab(QtWidgets.QWidget):

    def __init__(self, parent, title):

        QtWidgets.QWidget.__init__(self, parent)

        self.data_tab_glayout = QtWidgets.QGridLayout(self)
        self.figure = Figure(figsize=(5, 3))
        self.canvas = FigureCanvas(self.figure)
        self.canvas.setParent(parent)

        self.canvas.setFocusPolicy(QtCore.Qt.ClickFocus)
        self.canvas.setFocus()

        self.canvas_vlayout = QtWidgets.QVBoxLayout(self.canvas)
        self.data_tab_glayout.addWidget(self.canvas, 0, 0, 2, 1)

        #self.figure.canvas.mpl_connect('button_press_event', self.onclick)

        self.axe = self.canvas.figure.add_subplot(111)
        self.canvas.figure.subplots_adjust(left=0.025, top=0.965, bottom=0.040, right=0.975)
        # add the tab to the parent
        parent.addTab(self, "")

        # set text name
        parent.setTabText(parent.indexOf(self), title)

    def createInteractiveGraph(self, DG, pos, node_color, labels):

        self.graph_instance = InteractiveGraph(DG, node_layout=pos, edge_layout='curved', origin=(-1, -1), scale=(2, 2),
                                          node_color=node_color, node_size=8.,
                                          node_labels=True, node_label_fontdict=dict(size=10),
                                          edge_labels=labels, edge_label_fontdict=dict(size=10), ax=self.axe
                                          )

        self.canvas.draw()


class MyApp(QtWidgets.QMainWindow):

    def __init__(self, parent=None):
        QtWidgets.QMainWindow.__init__(self, parent)

        self.showMaximized()

        self.centralwidget = QtWidgets.QWidget(self)
        self.verticalLayout = QtWidgets.QVBoxLayout(self.centralwidget)
        self.core_tab = QtWidgets.QTabWidget(self.centralwidget)
        self.verticalLayout.addWidget(self.core_tab)
        self.add_tab_btn = QtWidgets.QPushButton(self.centralwidget)
        self.verticalLayout.addWidget(self.add_tab_btn)
        self.refresh_tab_btn = QtWidgets.QPushButton(self.centralwidget)
        self.verticalLayout.addWidget(self.refresh_tab_btn)
        self.setCentralWidget(self.centralwidget)

        self.add_tab_btn.setText("Add Tab")
        self.refresh_tab_btn.setText("Refresh Tabs")

        self.core_tab.setEnabled(True)
        self.core_tab.setTabShape(QtWidgets.QTabWidget.Rounded)
        self.core_tab.setElideMode(QtCore.Qt.ElideNone)
        self.core_tab.setDocumentMode(False)
        self.core_tab.setTabsClosable(True)
        self.core_tab.setMovable(True)
        self.core_tab.setTabBarAutoHide(False)

        self.tab_counter = 0

        self.random_tabs = [("a", ["b", "c"]),
                            ("d", ["e", "f", "g"]),
                            ("h", ["i", "j", "k", "l"]),
                            ("m", ["n"]),
                            ("o", ["p", "q"]),
                            ("r", ["s", "t", "u", "v", "w", "x", "y", "z"])]

        self.add_tab_btn.clicked.connect(self.openRandomTab)
        self.refresh_tab_btn.clicked.connect(self.refreshAllTabs)

    def openRandomTab(self):

        tab = data_tab(self.core_tab, "test "   str(self.tab_counter))
        self._drawDataGraph(self.tab_counter % len(self.random_tabs), tab)
        self.tab_counter  = 1

        self.core_tab.setCurrentIndex(self.core_tab.indexOf(tab))


    def _drawDataGraph(self, tabNb, dataWidget):
        dataWidget.axe.cla()

        # 1. draw graph
        producer = self.random_tabs[tabNb][0]
        consumers = self.random_tabs[tabNb][1]

        color_map = []
        DG = nx.DiGraph()
        for i, cons in enumerate(consumers):
            DG.add_edge(producer, cons, label=f"edge-{i}")

        node_color = dict()
        for node in DG:
            if node in producer:
                node_color[node] = "#DCE46F"
            else:
                node_color[node] = "#6FA2E4"
        pos = nx.shell_layout(DG)
        pos[producer] = pos[producer]   np.array([0.2, 0])
        labels = nx.get_edge_attributes(DG, 'label')

        dataWidget.createInteractiveGraph(DG, pos, node_color, labels)


    def refreshAllTabs(self):

        # loop through all pages and associated to get
        for tab_index in range(self.core_tab.count()):
            data_tab_widget = self.core_tab.widget(tab_index)

            # draw graph
            self._drawDataGraph(tab_index % len(self.random_tabs), data_tab_widget)




sys.argv = ['']
app = QtWidgets.QApplication(sys.argv)
main_app = MyApp()
main_app.show()
app.exec_()
  • Related