I have a window that holds two plots: one is a 2D plot and the other is a cut through that plot at a chosen position on its y-axis. I would like to select which cut is shown by dragging a horizontal bar to a position and having the 1D plot update. I am having trouble relating the position of the line item in the scene to the axes of the plot. The place where I would work this out is the plot_position method of the main window. I am also open to other ideas for how to go about this. Here is a screenshot for reference:
import sys
from PyQt5.Qt import Qt, QObject, QPen, QPointF
from PyQt5.QtCore import pyqtSignal
from PyQt5.QtWidgets import QSizePolicy
from PyQt5.QtWidgets import QApplication, QMainWindow, QGraphicsLineItem, QGraphicsView, \
QGraphicsScene, QWidget, QHBoxLayout
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
from matplotlib.figure import Figure
import xarray as xr
import numpy as np
class Signals(QObject):
    """Holder for the custom Qt signals used by the window and its items."""

    # Carries two floats: the scene x and scene y of a mouse release.
    bttnReleased = pyqtSignal(float, float)
class HLineItem(QGraphicsLineItem):
    """Red line item that can be dragged vertically inside a scene.

    The item keeps its x position while it is dragged and reports the scene
    position of the cursor through ``signals.bttnReleased`` when the mouse
    button is released.
    """

    def __init__(self, signals):
        """Create the line item.

        Args:
            signals: Object exposing a ``bttnReleased`` signal that accepts
                two floats (scene x, scene y).
        """
        super(HLineItem, self).__init__()
        self.signals = signals
        self.setPen(QPen(Qt.red, 3))
        self.setFlag(QGraphicsLineItem.ItemIsMovable)
        self.setCursor(Qt.OpenHandCursor)
        self.setAcceptHoverEvents(True)

    def mouseMoveEvent(self, event):
        """Follow the cursor vertically; the x position never changes."""
        orig_cursor_position = event.lastScenePos()
        updated_cursor_position = event.scenePos()
        orig_position = self.scenePos()
        # New y = distance the cursor travelled since the last event, added
        # to the item's current y.
        updated_cursor_y = (updated_cursor_position.y()
                            - orig_cursor_position.y()
                            + orig_position.y())
        self.setPos(QPointF(orig_position.x(), updated_cursor_y))

    def mouseReleaseEvent(self, event):
        """Emit the scene coordinates at which the button was released."""
        release_point = event.scenePos()
        self.signals.bttnReleased.emit(release_point.x(), release_point.y())
class PlotCanvas(FigureCanvas):
    """Matplotlib canvas widget that draws an xarray ``DataArray``."""

    def __init__(self, parent=None, width=5, height=4, dpi=100):
        """Build an empty figure of ``width`` x ``height`` inches at ``dpi``.

        Args:
            parent: Widget passed to ``setParent``.
            width: Figure width in inches.
            height: Figure height in inches.
            dpi: Figure resolution in dots per inch.
        """
        self.fig = Figure(figsize=(width, height), dpi=dpi)
        super().__init__(self.fig)
        self.setParent(parent)
        # Ask layouts to let the canvas expand in both directions.
        self.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding)
        self.updateGeometry()
        self.data = xr.DataArray()
        self.axes = None

    def plot(self, data):
        """Store ``data`` and draw it on a newly added subplot.

        The x-axis is limited to [-0.5, 0.5] before the canvas is redrawn.
        """
        self.data = data
        self.axes = self.fig.add_subplot(111)
        data.plot(ax=self.axes)
        self.axes.set_xlim(-.5, .5)
        self.draw()
class MainWindow(QMainWindow):
    """Window showing a 2D cut next to a 1D slice picked with a draggable line.

    The left side is a graphics view holding the 2D plot with a red
    horizontal line on top of it; the right side is the 1D plot of the cut
    taken at the ``slit`` value under that line.
    """

    def __init__(self):
        """Build the demo data, both plots, the scene and the layouts."""
        super(MainWindow, self).__init__()
        self.signals = Signals()

        # Synthetic 51 x 51 x 51 volume sampled on [-1, 1] along each axis.
        x = np.linspace(-1, 1, 51)
        y = np.linspace(-1, 1, 51)
        z = np.linspace(-1, 1, 51)
        xyz = np.meshgrid(x, y, z, indexing='ij')
        radius_sq = xyz[0] ** 2 + xyz[1] ** 2 + xyz[2] ** 2
        d = np.sin(np.pi * np.exp(-1 * radius_sq)) * np.cos(np.pi / 2 * xyz[1])
        self.xar = xr.DataArray(d, coords={"slit": x, 'perp': y, "energy": z},
                                dims=["slit", "perp", "energy"])
        self.cut = self.xar.sel({"perp": 0}, method='nearest')
        self.edc = self.cut.sel({'slit': 0}, method='nearest')

        self.canvas = PlotCanvas()
        self.canvas_edc = PlotCanvas()
        self.canvas.plot(self.cut)
        self.canvas_edc.plot(self.edc)

        self.view = QGraphicsView()
        self.scene = QGraphicsScene()
        self.line = HLineItem(self.signals)
        self.line_pos = [0, 0]
        self.layout1 = QHBoxLayout()
        self.layout2 = QHBoxLayout()
        self.connect_scene()
        self.layout1.addWidget(self.view)
        self.layout2.addWidget(self.canvas_edc)

        self.central = QWidget()
        self.main_layout = QHBoxLayout()
        self.main_layout.addLayout(self.layout1)
        self.main_layout.addLayout(self.layout2)
        self.central.setLayout(self.main_layout)
        self.setCentralWidget(self.central)

        self.signals.bttnReleased.connect(self.plot_position)

    def connect_scene(self):
        """Put the 2D canvas and the draggable line into the graphics scene."""
        self.view.setScene(self.scene)
        self.scene.addWidget(self.canvas)
        # The line spans the width the scene has once the canvas is added.
        self.line.setLine(0, 0, self.scene.sceneRect().width(), 0)
        self.line.setPos(self.line_pos[0], self.line_pos[1])
        self.scene.addItem(self.line)

    def handle_plotting(self):
        """Replace the 1D plot widget with one drawn from ``self.edc``."""
        self.clearLayout(self.layout2)
        self.refresh_edc()

    def refresh_edc(self):
        """Create a new canvas for ``self.edc`` and add it to ``layout2``."""
        self.canvas_edc = PlotCanvas()
        self.canvas_edc.plot(self.edc)
        self.layout2.addWidget(self.canvas_edc)

    def clearLayout(self, layout):
        """Remove every item from ``layout`` and schedule its widgets for deletion."""
        while layout.count():
            child = layout.takeAt(0)
            if child.widget():
                child.widget().deleteLater()

    def plot_position(self, x, y):
        """Re-plot the 1D cut at the ``slit`` value under a scene point.

        Args:
            x: Scene x coordinate of the mouse release.
            y: Scene y coordinate of the mouse release.
        """
        self.line_pos = [x, y]
        sel_val = self._scene_y_to_data(y)
        # method='nearest' snaps to the closest slit coordinate, so positions
        # above or below the axes simply select the edge of the data.
        self.edc = self.cut.sel({"slit": sel_val}, method='nearest')
        self.handle_plotting()

    def _scene_y_to_data(self, scene_y):
        """Convert a scene y coordinate to a value on the 2D plot's y-axis."""
        figure_height = self.canvas.figure.bbox.height
        widget_height = self.canvas.height()
        # Matplotlib display coordinates are physical pixels while Qt reports
        # logical ones; the ratio of the two heights converts between them.
        scale = figure_height / widget_height if widget_height else 1.0
        # The canvas is never moved after scene.addWidget, so scene y is taken
        # as canvas y. Qt's y grows downwards from the top edge, Matplotlib's
        # display y grows upwards from the bottom edge, hence the flip.
        display_y = figure_height - scene_y * scale
        # NOTE(review): relies on xarray drawing the first dimension ("slit")
        # of the 2D cut on the y-axis, which is its documented default.
        _, data_y = self.canvas.axes.transData.inverted().transform(
            (0, display_y))
        return float(data_y)
class App(QApplication):
    """Application object that creates and shows the main window."""

    def __init__(self, sys_argv):
        """Start Qt and show the main window.

        Args:
            sys_argv: Command-line arguments handed to ``QApplication``.
        """
        # Qt only honours AA_EnableHighDpiScaling when it is set before the
        # application object is constructed, so set it ahead of the base
        # class initialiser.
        # NOTE(review): with scaling active, code that compares Matplotlib
        # display pixels with Qt coordinates must allow for the device pixel
        # ratio.
        QApplication.setAttribute(Qt.AA_EnableHighDpiScaling)
        super(App, self).__init__(sys_argv)
        self.mainWindow = MainWindow()
        self.mainWindow.setWindowTitle("arpys")
        self.mainWindow.show()
def main():
    """Create the application, run its event loop and exit with its status."""
    application = App(sys.argv)
    exit_code = application.exec_()
    sys.exit(exit_code)


# Start the GUI only when this file is executed as a script.
if __name__ == "__main__":
    main()
CodePudding user response:
This took me quite a while to figure out. The frustrating part was dealing with the positions Matplotlib reports for its objects and understanding which classes contain what. For a while I tried to use axes.get_position, but this returned a height that was far too small. I am still not sure what this function defines as y1 and y0 when it gives the height. Next I looked at axes.get_window_extent (which had a similar problem) and axes.get_tightbbox. The get_tightbbox function returns the bounding box of the axes including their decorators (xlabel, title, etc.), which was better, but it gave me a height beyond the extent I wanted because it included those decorators. Finally, I found that what I really wanted was the length of the spine! In the end I used the spine.get_window_extent() function, and this was exactly what I needed. I have attached the updated code.
It was helpful to look at this diagram of the inheritances to see what I was dealing with/looking for.
import sys
from PyQt5.Qt import Qt, QObject, QPen, QPointF
from PyQt5.QtCore import pyqtSignal
from PyQt5.QtWidgets import QSizePolicy
from PyQt5.QtWidgets import QApplication, QMainWindow, QGraphicsLineItem, QGraphicsView, \
QGraphicsScene, QWidget, QHBoxLayout
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
from matplotlib.figure import Figure
import xarray as xr
import numpy as np
class Signals(QObject):
    """Holder for the custom Qt signals used by the window and its items."""

    # Carries one float: the scene y of a mouse release.
    bttnReleased = pyqtSignal(float)
class HLineItem(QGraphicsLineItem):
    """Red line item that can be dragged vertically inside a scene.

    The item keeps its x position while it is dragged and reports the scene
    y position of the cursor through ``signals.bttnReleased`` when the mouse
    button is released.
    """

    def __init__(self, signals):
        """Create the line item.

        Args:
            signals: Object exposing a ``bttnReleased`` signal that accepts
                one float (scene y).
        """
        super(HLineItem, self).__init__()
        self.signals = signals
        self.setPen(QPen(Qt.red, 3))
        self.setFlag(QGraphicsLineItem.ItemIsMovable)
        self.setCursor(Qt.OpenHandCursor)
        self.setAcceptHoverEvents(True)

    def mouseMoveEvent(self, event):
        """Follow the cursor vertically; the x position never changes."""
        orig_cursor_position = event.lastScenePos()
        updated_cursor_position = event.scenePos()
        orig_position = self.scenePos()
        # New y = distance the cursor travelled since the last event, added
        # to the item's current y.
        updated_cursor_y = (updated_cursor_position.y()
                            - orig_cursor_position.y()
                            + orig_position.y())
        self.setPos(QPointF(orig_position.x(), updated_cursor_y))

    def mouseReleaseEvent(self, event):
        """Emit the scene y coordinate at which the button was released."""
        y_pos = event.scenePos().y()
        self.signals.bttnReleased.emit(y_pos)
class PlotCanvas(FigureCanvas):
    """Matplotlib canvas widget that draws an xarray ``DataArray``."""

    def __init__(self, parent=None, width=5, height=4, dpi=100):
        """Build an empty figure of ``width`` x ``height`` inches at ``dpi``.

        Args:
            parent: Widget passed to ``setParent``.
            width: Figure width in inches.
            height: Figure height in inches.
            dpi: Figure resolution in dots per inch.
        """
        self.fig = Figure(figsize=(width, height), dpi=dpi)
        super().__init__(self.fig)
        self.setParent(parent)
        # Ask layouts to let the canvas expand in both directions.
        self.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding)
        self.updateGeometry()
        self.data = xr.DataArray()
        self.axes = None

    def plot(self, data):
        """Store ``data`` and draw it on a newly added subplot.

        The left and bottom margins are set to 0.2 and the x-axis is limited
        to [-0.5, 0.5] before the canvas is redrawn.
        """
        self.data = data
        self.axes = self.fig.add_subplot(111)
        data.plot(ax=self.axes)
        self.fig.subplots_adjust(left=0.2, bottom=0.2)
        self.axes.set_xlim(-.5, .5)
        self.draw()
class MainWindow(QMainWindow):
    """Window showing a 2D cut next to a 1D slice picked with a draggable line.

    The left side is a graphics view holding the 2D plot with a red
    horizontal line on top of it; the right side is the 1D plot of the cut
    taken at the ``slit`` value under that line.
    """

    def __init__(self):
        """Build the demo data, both plots, the scene and the layouts."""
        super(MainWindow, self).__init__()
        self.signals = Signals()

        # Synthetic 51 x 51 x 51 volume sampled on [-1, 1] along each axis.
        x = np.linspace(-1, 1, 51)
        y = np.linspace(-1, 1, 51)
        z = np.linspace(-1, 1, 51)
        xyz = np.meshgrid(x, y, z, indexing='ij')
        radius_sq = xyz[0] ** 2 + xyz[1] ** 2 + xyz[2] ** 2
        d = np.sin(np.pi * np.exp(-1 * radius_sq)) * np.cos(np.pi / 2 * xyz[1])
        self.xar = xr.DataArray(d, coords={"slit": x, 'perp': y, "energy": z},
                                dims=["slit", "perp", "energy"])
        self.cut = self.xar.sel({"perp": 0}, method='nearest')
        self.edc = self.cut.sel({'slit': 0}, method='nearest')

        self.canvas = PlotCanvas()
        self.canvas_edc = PlotCanvas()
        self.canvas.plot(self.cut)
        self.canvas_edc.plot(self.edc)

        self.view = QGraphicsView()
        self.scene = QGraphicsScene()
        self.line = HLineItem(self.signals)
        self.line_pos = [0, 0]
        self.layout1 = QHBoxLayout()
        self.layout2 = QHBoxLayout()
        self.connect_scene()
        self.layout1.addWidget(self.view)
        self.layout2.addWidget(self.canvas_edc)

        self.central = QWidget()
        self.main_layout = QHBoxLayout()
        self.main_layout.addLayout(self.layout1)
        self.main_layout.addLayout(self.layout2)
        self.central.setLayout(self.main_layout)
        self.setCentralWidget(self.central)

        self.signals.bttnReleased.connect(self.plot_position)

    def connect_scene(self):
        """Put the 2D canvas and the draggable line into the graphics scene."""
        # Figure size in pixels, used to pin the scene rectangle to the plot.
        s = self.canvas.figure.get_size_inches() * self.canvas.figure.dpi
        self.view.setScene(self.scene)
        self.scene.addWidget(self.canvas)
        self.scene.setSceneRect(0, 0, s[0], s[1])
        self.line.setLine(0, 0, self.scene.sceneRect().width(), 0)
        self.line.setPos(self.line_pos[0], self.line_pos[1])
        self.scene.addItem(self.line)

    def handle_plotting(self):
        """Replace the 1D plot widget with one drawn from ``self.edc``."""
        self.clearLayout(self.layout2)
        self.refresh_edc()

    def refresh_edc(self):
        """Create a new canvas for ``self.edc`` and add it to ``layout2``."""
        self.canvas_edc = PlotCanvas()
        self.canvas_edc.plot(self.edc)
        self.layout2.addWidget(self.canvas_edc)

    def clearLayout(self, layout):
        """Remove every item from ``layout`` and schedule its widgets for deletion."""
        while layout.count():
            child = layout.takeAt(0)
            if child.widget():
                child.widget().deleteLater()

    def plot_position(self, y):
        """Clamp the line to the plot's y-axis and re-plot the cut beneath it.

        Args:
            y: Scene y coordinate at which the mouse button was released.
        """
        scene_height = self.scene.sceneRect().height()

        def rel_pos(value):
            # Flip between Qt's top-down y and Matplotlib's bottom-up y.
            return abs(scene_height - value)

        # The left spine covers exactly the drawn extent of the y-axis.
        bbox = self.canvas.axes.spines['left'].get_window_extent()
        scale = self._display_scale()
        lower, upper = bbox.y0 / scale, bbox.y1 / scale

        # Snap the line back onto the axis if it was dropped outside of it.
        if rel_pos(y) < lower:
            self.line.setPos(0, rel_pos(lower))
        elif rel_pos(y) > upper:
            self.line.setPos(0, rel_pos(upper))
        line_y = self.line.pos().y()
        # Kept as an [x, y] pair, the shape connect_scene indexes into.
        self.line_pos = [self.line.pos().x(), line_y]

        # Spread the slit coordinates evenly along the spine and pick the one
        # whose pixel position is closest to the line.
        slit_values = self.cut.slit.values
        pixel_grid = np.linspace(lower, upper, len(slit_values))
        nearest = int(np.argmin(np.abs(pixel_grid - rel_pos(line_y))))
        self.edc = self.cut.sel({"slit": slit_values[nearest]},
                                method='nearest')
        self.handle_plotting()

    def _display_scale(self):
        """Return Matplotlib display pixels per Qt pixel for the 2D canvas."""
        widget_height = self.canvas.height()
        if not widget_height:
            return 1.0
        # The figure is rendered in physical pixels, the widget is measured
        # in logical ones; the ratio is 1 unless high-DPI scaling is active.
        return self.canvas.figure.bbox.height / widget_height
class App(QApplication):
    """Application object that creates and shows the main window."""

    def __init__(self, sys_argv):
        """Start Qt and show the main window.

        Args:
            sys_argv: Command-line arguments handed to ``QApplication``.
        """
        # Qt only honours AA_EnableHighDpiScaling when it is set before the
        # application object is constructed, so set it ahead of the base
        # class initialiser.
        # NOTE(review): with scaling active, code that compares Matplotlib
        # display pixels with Qt coordinates must allow for the device pixel
        # ratio.
        QApplication.setAttribute(Qt.AA_EnableHighDpiScaling)
        super(App, self).__init__(sys_argv)
        self.mainWindow = MainWindow()
        self.mainWindow.setWindowTitle("arpys")
        self.mainWindow.show()
def main():
    """Create the application, run its event loop and exit with its status."""
    application = App(sys.argv)
    exit_code = application.exec_()
    sys.exit(exit_code)


# Start the GUI only when this file is executed as a script.
if __name__ == "__main__":
    main()