Home > Mobile >  How to relate QGraphicsScene position to position on a matplotlib plot axis
How to relate QGraphicsScene position to position on a matplotlib plot axis

Time:09-09

I have a window which holds two plots: one is a 2D plot and the other is a 1D cut taken along the y-axis of that plot. I would like to select which cut is shown by dragging a horizontal bar to a position and having the 1D plot update. I am having trouble relating the position of the line item in the scene to the axes of the plot. The part where I am trying to figure this out is the plot_position method of the main window. I am also open to other ideas of how to go about this. Here is a screenshot for reference:

enter image description here

import sys
from PyQt5.Qt import Qt, QObject, QPen, QPointF
from PyQt5.QtCore import pyqtSignal
from PyQt5.QtWidgets import QSizePolicy
from PyQt5.QtWidgets import QApplication, QMainWindow, QGraphicsLineItem, QGraphicsView, \
    QGraphicsScene, QWidget, QHBoxLayout
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
from matplotlib.figure import Figure
import xarray as xr
import numpy as np


class Signals(QObject):
    """Container for the custom Qt signals used by this script."""

    # Carries two floats; HLineItem emits it with the scene x and y
    # coordinates of the mouse-release event.
    bttnReleased = pyqtSignal(float, float)


class HLineItem(QGraphicsLineItem):
    """Red line item that can be dragged vertically inside a scene.

    The item keeps its x position while being dragged and reports the
    scene position of the mouse release through ``signals.bttnReleased``.
    """

    def __init__(self, signals):
        """Create the line.

        :param signals: object exposing a ``bttnReleased`` signal that
            accepts two floats (x, y).
        """
        super().__init__()
        self.signals = signals
        self.setPen(QPen(Qt.red, 3))
        self.setFlag(QGraphicsLineItem.ItemIsMovable)
        self.setCursor(Qt.OpenHandCursor)
        self.setAcceptHoverEvents(True)

    def mouseMoveEvent(self, event):
        """Move the item by the vertical cursor displacement only."""
        # Displacement of the cursor since the previous mouse event.
        cursor_delta_y = event.scenePos().y() - event.lastScenePos().y()
        current = self.scenePos()
        # x is left untouched so the drag is constrained to the y direction.
        self.setPos(QPointF(current.x(), cursor_delta_y + current.y()))

    def mouseReleaseEvent(self, event):
        """Emit the scene coordinates at which the mouse was released."""
        release_point = event.scenePos()
        self.signals.bttnReleased.emit(release_point.x(), release_point.y())


class PlotCanvas(FigureCanvas):
    """Qt widget wrapping a matplotlib figure that shows one DataArray."""

    def __init__(self, parent=None, width=5, height=4, dpi=100):
        """Build an empty canvas.

        :param parent: Qt parent widget, or None.
        :param width: figure width in inches.
        :param height: figure height in inches.
        :param dpi: figure resolution in dots per inch.
        """
        self.fig = Figure(figsize=(width, height), dpi=dpi)
        super().__init__(self.fig)
        self.setParent(parent)
        # Allow the canvas to grow in both directions inside a layout.
        self.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding)
        self.updateGeometry()
        self.data = xr.DataArray()
        # Set by plot(); None until something has been drawn.
        self.axes = None

    def plot(self, data):
        """Draw ``data`` on a new subplot and limit the x-axis to +/-0.5."""
        self.data = data
        axes = self.fig.add_subplot(111)
        self.axes = axes
        data.plot(ax=axes)
        axes.set_xlim(-.5, .5)
        self.draw()


class MainWindow(QMainWindow):
    """Main window showing a 2-D cut (left) and a 1-D slice of it (right).

    The 2-D canvas is embedded in a QGraphicsScene together with a
    draggable horizontal line; releasing the line triggers plot_position().
    """

    def __init__(self):
        """Generate the sample data, build both plots and wire the signal."""
        super().__init__()
        self.signals = Signals()

        # Synthetic 51x51x51 volume on a [-1, 1] grid along every axis.
        x = np.linspace(-1, 1, 51)
        y = np.linspace(-1, 1, 51)
        z = np.linspace(-1, 1, 51)
        xyz = np.meshgrid(x, y, z, indexing='ij')
        d = np.sin(np.pi * np.exp(-1 * (xyz[0] ** 2 + xyz[1] ** 2 + xyz[2] ** 2))) * np.cos(np.pi / 2 * xyz[1])
        self.xar = xr.DataArray(d, coords={"slit": x, 'perp': y, "energy": z}, dims=["slit", "perp", "energy"])
        # 2-D cut at the "perp" coordinate nearest 0, and its 1-D slice at
        # the "slit" coordinate nearest 0.
        self.cut = self.xar.sel({"perp": 0}, method='nearest')
        self.edc = self.cut.sel({'slit': 0}, method='nearest')
        self.canvas = PlotCanvas()
        self.canvas_edc = PlotCanvas()
        self.canvas.plot(self.cut)
        self.canvas_edc.plot(self.edc)

        self.view = QGraphicsView()
        self.scene = QGraphicsScene()
        self.line = HLineItem(self.signals)
        # Last known [x, y] position of the line in the scene.
        self.line_pos = [0, 0]
        self.layout1 = QHBoxLayout()
        self.layout2 = QHBoxLayout()
        self.connect_scene()

        self.layout1.addWidget(self.view)
        self.layout2.addWidget(self.canvas_edc)
        self.central = QWidget()
        self.main_layout = QHBoxLayout()
        self.main_layout.addLayout(self.layout1)
        self.main_layout.addLayout(self.layout2)
        self.central.setLayout(self.main_layout)
        self.setCentralWidget(self.central)

        self.signals.bttnReleased.connect(self.plot_position)

    def connect_scene(self):
        """Put the 2-D canvas and the draggable line into the scene."""
        # Figure size in pixels; only needed by the disabled setSceneRect call.
        s = self.canvas.figure.get_size_inches() * self.canvas.figure.dpi
        self.view.setScene(self.scene)
        self.scene.addWidget(self.canvas)
        # self.scene.setSceneRect(0, 0, s[0], s[1])
        # self.capture_scene_change()
        # The line spans the full width of the scene at its local y = 0.
        self.line.setLine(0, 0, self.scene.sceneRect().width(), 0)
        self.line.setPos(self.line_pos[0], self.line_pos[1])
        self.scene.addItem(self.line)

    def handle_plotting(self):
        """Replace the 1-D plot with one showing the current ``self.edc``."""
        self.clearLayout(self.layout2)
        self.refresh_edc()

    def refresh_edc(self):
        """Create a fresh canvas for ``self.edc`` and add it to layout2."""
        self.canvas_edc = PlotCanvas()
        self.canvas_edc.plot(self.edc)
        self.layout2.addWidget(self.canvas_edc)

    def clearLayout(self, layout):
        """Remove every item from ``layout`` and schedule widgets for deletion."""
        while layout.count():
            child = layout.takeAt(0)
            if child.widget():
                child.widget().deleteLater()

    def plot_position(self, x, y):
        """Handle a release of the line at scene position (x, y).

        NOTE: this is still a placeholder. The scene y value is not yet
        converted to a value on the plot's y-axis, so the 1-D plot is
        always rebuilt from the slice at slit = 0.
        """
        self.line_pos = [x, y]
        plot_bbox = self.canvas.axes.get_position()
        # something here to relate the hline position in the scene to the axes positions/plot axes
        sel_val = min(self.cut.slit, key=lambda f: abs(f - y))  # this should not be y, but rather
        # the corresponding value on the y-axis
        self.edc = self.cut.sel({"slit": 0}, method='nearest')  # this should not be 0, but rather
        # the sel_val
        self.handle_plotting()


class App(QApplication):
    """Application object that creates and shows the main window."""

    def __init__(self, sys_argv):
        """Start the application.

        :param sys_argv: command-line argument list passed to QApplication.
        """
        # Qt only honours AA_EnableHighDpiScaling when it is set before the
        # application object is constructed, so set it on the class first.
        QApplication.setAttribute(Qt.AA_EnableHighDpiScaling)
        super().__init__(sys_argv)
        self.mainWindow = MainWindow()
        self.mainWindow.setWindowTitle("arpys")
        self.mainWindow.show()


def main():
    """Create the application, run its event loop and exit with its status."""
    application = App(sys.argv)
    exit_code = application.exec_()
    sys.exit(exit_code)


# Only start the GUI when this file is executed as a script.
if __name__ == "__main__":
    main()

CodePudding user response:

This took me quite a while to figure out. The frustrating part was dealing with the positions of Matplotlib's objects and understanding which classes contain what. For a while I tried to use axes.get_position(), but it returned a height that was far too small. I am still not sure what this function defines as y0 and y1 to give the height. Next I looked at axes.get_window_extent() (which had a similar problem) and axes.get_tightbbox(). get_tightbbox() returns the bounding box of the axes including their decorators (xlabel, title, etc.), which was better but gave me a height beyond the extent I wanted, precisely because it includes these decorators. Finally, I found that what I really wanted was the length of the spine! In the end I used spine.get_window_extent(), and this was exactly what I needed. I have attached the updated code.

It was helpful to look at this diagram of the inheritances to see what I was dealing with/looking for.

import sys
from PyQt5.Qt import Qt, QObject, QPen, QPointF
from PyQt5.QtCore import pyqtSignal
from PyQt5.QtWidgets import QSizePolicy
from PyQt5.QtWidgets import QApplication, QMainWindow, QGraphicsLineItem, QGraphicsView, \
    QGraphicsScene, QWidget, QHBoxLayout
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
from matplotlib.figure import Figure
import xarray as xr
import numpy as np


class Signals(QObject):
    """Container for the custom Qt signals used by this script."""

    # Carries one float; HLineItem emits it with the scene y coordinate
    # of the mouse-release event.
    bttnReleased = pyqtSignal(float)


class HLineItem(QGraphicsLineItem):
    """Red line item that can be dragged vertically inside a scene.

    The item keeps its x position while being dragged and reports the
    scene y coordinate of the mouse release through
    ``signals.bttnReleased``.
    """

    def __init__(self, signals):
        """Create the line.

        :param signals: object exposing a ``bttnReleased`` signal that
            accepts one float (y).
        """
        super().__init__()
        self.signals = signals
        self.setPen(QPen(Qt.red, 3))
        self.setFlag(QGraphicsLineItem.ItemIsMovable)
        self.setCursor(Qt.OpenHandCursor)
        self.setAcceptHoverEvents(True)

    def mouseMoveEvent(self, event):
        """Move the item by the vertical cursor displacement only."""
        # Displacement of the cursor since the previous mouse event.
        cursor_delta_y = event.scenePos().y() - event.lastScenePos().y()
        current = self.scenePos()
        # x is left untouched so the drag is constrained to the y direction.
        self.setPos(QPointF(current.x(), cursor_delta_y + current.y()))

    def mouseReleaseEvent(self, event):
        """Emit the scene y coordinate at which the mouse was released."""
        self.signals.bttnReleased.emit(event.scenePos().y())


class PlotCanvas(FigureCanvas):
    """Qt widget wrapping a matplotlib figure that shows one DataArray."""

    def __init__(self, parent=None, width=5, height=4, dpi=100):
        """Build an empty canvas.

        :param parent: Qt parent widget, or None.
        :param width: figure width in inches.
        :param height: figure height in inches.
        :param dpi: figure resolution in dots per inch.
        """
        self.fig = Figure(figsize=(width, height), dpi=dpi)
        super().__init__(self.fig)
        self.setParent(parent)
        # Allow the canvas to grow in both directions inside a layout.
        self.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding)
        self.updateGeometry()
        self.data = xr.DataArray()
        # Set by plot(); None until something has been drawn.
        self.axes = None

    def plot(self, data):
        """Draw ``data`` on a new subplot with extra left/bottom margin."""
        self.data = data
        axes = self.fig.add_subplot(111)
        self.axes = axes
        data.plot(ax=axes)
        # Widen the left and bottom margins of the subplot.
        self.fig.subplots_adjust(left=0.2, bottom=0.2)
        axes.set_xlim(-.5, .5)
        self.draw()


class MainWindow(QMainWindow):
    """Main window showing a 2-D cut (left) and a 1-D slice of it (right).

    The 2-D canvas is embedded in a QGraphicsScene together with a
    draggable horizontal line; releasing the line selects the "slit"
    value under it and redraws the 1-D plot for that value.
    """

    def __init__(self):
        """Generate the sample data, build both plots and wire the signal."""
        super().__init__()
        self.signals = Signals()

        # Synthetic 51x51x51 volume on a [-1, 1] grid along every axis.
        x = np.linspace(-1, 1, 51)
        y = np.linspace(-1, 1, 51)
        z = np.linspace(-1, 1, 51)
        xyz = np.meshgrid(x, y, z, indexing='ij')
        d = np.sin(np.pi * np.exp(-1 * (xyz[0] ** 2 + xyz[1] ** 2 + xyz[2] ** 2))) * np.cos(np.pi / 2 * xyz[1])
        self.xar = xr.DataArray(d, coords={"slit": x, 'perp': y, "energy": z}, dims=["slit", "perp", "energy"])
        # 2-D cut at the "perp" coordinate nearest 0, and its 1-D slice at
        # the "slit" coordinate nearest 0.
        self.cut = self.xar.sel({"perp": 0}, method='nearest')
        self.edc = self.cut.sel({'slit': 0}, method='nearest')
        self.canvas = PlotCanvas()
        self.canvas_edc = PlotCanvas()
        self.canvas.plot(self.cut)
        self.canvas_edc.plot(self.edc)

        self.view = QGraphicsView()
        self.scene = QGraphicsScene()
        self.line = HLineItem(self.signals)
        # Last known [x, y] position of the line in the scene.
        self.line_pos = [0, 0]
        self.layout1 = QHBoxLayout()
        self.layout2 = QHBoxLayout()
        self.connect_scene()

        self.layout1.addWidget(self.view)
        self.layout2.addWidget(self.canvas_edc)
        self.central = QWidget()
        self.main_layout = QHBoxLayout()
        self.main_layout.addLayout(self.layout1)
        self.main_layout.addLayout(self.layout2)
        self.central.setLayout(self.main_layout)
        self.setCentralWidget(self.central)

        self.signals.bttnReleased.connect(self.plot_position)

    def connect_scene(self):
        """Put the 2-D canvas and the draggable line into the scene."""
        # Figure size in pixels (inches * dpi), used as the scene rectangle.
        s = self.canvas.figure.get_size_inches() * self.canvas.figure.dpi
        self.view.setScene(self.scene)
        self.scene.addWidget(self.canvas)
        self.scene.setSceneRect(0, 0, s[0], s[1])
        # The line spans the full width of the scene at its local y = 0.
        self.line.setLine(0, 0, self.scene.sceneRect().width(), 0)
        self.line.setPos(self.line_pos[0], self.line_pos[1])
        self.scene.addItem(self.line)

    def handle_plotting(self):
        """Replace the 1-D plot with one showing the current ``self.edc``."""
        self.clearLayout(self.layout2)
        self.refresh_edc()

    def refresh_edc(self):
        """Create a fresh canvas for ``self.edc`` and add it to layout2."""
        self.canvas_edc = PlotCanvas()
        self.canvas_edc.plot(self.edc)
        self.layout2.addWidget(self.canvas_edc)

    def clearLayout(self, layout):
        """Remove every item from ``layout`` and schedule widgets for deletion."""
        while layout.count():
            child = layout.takeAt(0)
            if child.widget():
                child.widget().deleteLater()

    def plot_position(self, y):
        """Select the slit value under the line and redraw the 1-D plot.

        :param y: scene y coordinate at which the mouse was released.

        The line is snapped back onto the left spine when the release
        happened above or below it. The line's pixel position is then
        converted to a y-axis data value with the axes' own display-to-data
        transform, and the nearest "slit" coordinate is selected.
        """
        scene_height = self.scene.sceneRect().height()

        def flip(value):
            # Scene y grows downward while matplotlib display y grows
            # upward; mirroring about the scene height converts either way.
            return abs(scene_height - value)

        # Pixel extent of the left spine, i.e. of the visible y-axis.
        spine_bbox = self.canvas.axes.spines['left'].get_window_extent()
        lowest, highest = spine_bbox.y0, spine_bbox.y1

        released_at = flip(y)
        if released_at < lowest:
            self.line.setPos(0, flip(lowest))
        elif released_at > highest:
            self.line.setPos(0, flip(highest))

        # Keep line_pos as an [x, y] list so connect_scene() can still
        # index it after a drag.
        line_point = self.line.pos()
        self.line_pos = [line_point.x(), line_point.y()]

        # NOTE(review): assumes scene pixels map 1:1 onto figure pixels
        # (no device-pixel-ratio scaling) -- confirm on high-DPI screens.
        display_y = min(max(flip(line_point.y()), lowest), highest)
        to_data = self.canvas.axes.transData.inverted()
        # Only the y component is needed; the x pixel value is arbitrary.
        slit_value = float(to_data.transform((0.0, display_y))[1])

        self.edc = self.cut.sel({"slit": slit_value}, method='nearest')
        self.handle_plotting()


class App(QApplication):
    """Application object that creates and shows the main window."""

    def __init__(self, sys_argv):
        """Start the application.

        :param sys_argv: command-line argument list passed to QApplication.
        """
        # Qt only honours AA_EnableHighDpiScaling when it is set before the
        # application object is constructed, so set it on the class first.
        QApplication.setAttribute(Qt.AA_EnableHighDpiScaling)
        super().__init__(sys_argv)
        self.mainWindow = MainWindow()
        self.mainWindow.setWindowTitle("arpys")
        self.mainWindow.show()


def main():
    """Create the application, run its event loop and exit with its status."""
    application = App(sys.argv)
    exit_code = application.exec_()
    sys.exit(exit_code)


# Only start the GUI when this file is executed as a script.
if __name__ == "__main__":
    main()
  • Related