# Source code for schrodinger.trajectory.trajectory_gui_dir.plots

"""
File containing plot related code used in the Trajectory Plots GUI
"""
import csv
from enum import Enum
from enum import auto

import openpyxl

from schrodinger.Qt import QtCore
from schrodinger.Qt import QtGui
from schrodinger.Qt import QtWidgets
from schrodinger.Qt import QtCharts
from schrodinger.Qt.QtCore import Qt
from schrodinger.ui.qt import basewidgets
from schrodinger.ui.qt import filedialog
from schrodinger.ui.qt.standard.icons import icons
from schrodinger.utils import csv_unicode
from schrodinger.tasks import tasks
from schrodinger import structure
from schrodinger import get_maestro

# Module handle returned by schrodinger.get_maestro(); presumably None when
# not running inside Maestro (TODO confirm). Not referenced elsewhere in this
# chunk of the file.
maestro = get_maestro()

from . import energy_plot_ui
from . import collapsible_plot_ui
from . import shortcut_ui
from . import traj_plot_models
from . import energy_plots
from . import plots as tplots
from .traj_plot_models import AnalysisMode

# Plot Constants
# Upper bound on the number of ticks placed on a value axis.
MAX_AXIS_TICKS = 5
# Padding added to axis bounds and the step used to derive the tick count.
MIN_AXIS_SPAN = 0.1
# Maximum number of advanced-plot shortcuts placed in one ShortcutRow.
MAX_SHORTCUTS_IN_ROW = 3
# Pixel width of images written by saveImage(); height follows the view's
# aspect ratio.
IMAGE_WIDTH = 20000

# Plot context menu actions (menu item labels; AdvancedPlotShortcut also
# dispatches on these exact texts)
SHOW = "Show"
HIDE = "Hide"
SAVE_IMG = "Save Image..."
EXPORT_CSV = "Export as CSV..."
EXPORT_EXCEL = "Export to Excel..."
DELETE = "Delete"
VIEW_PLOT = "View Plot..."


#############################
# ENUMS
#############################
class TrajectoryPlotType(Enum):
    """
    Enum of plot types to generate

    Member values are assigned by auto() and therefore depend on the
    declaration order below; members are grouped by the category prefix
    in their names.
    """
    # Measurements
    MEASUREMENT_WORKSPACE = auto()
    MEASUREMENT_ADD = auto()
    MEASUREMENT_PLANAR_ANGLE = auto()
    MEASUREMENT_CENTROID = auto()
    # Interactions
    INTERACTIONS_ALL = auto()
    INTERACTIONS_HYDROGEN_BONDS = auto()
    INTERACTIONS_HALOGEN_BONDS = auto()
    INTERACTIONS_SALT_BRIDGE = auto()
    INTERACTIONS_PI_PI = auto()
    INTERACTIONS_CAT_PI = auto()
    # Descriptors
    DESCRIPTORS_RMSD = auto()
    DESCRIPTORS_ATOM_RMSF = auto()
    DESCRIPTORS_RES_RMSF = auto()
    DESCRIPTORS_RADIUS_GYRATION = auto()
    DESCRIPTORS_PSA = auto()
    DESCRIPTORS_SASA = auto()
    DESCRIPTORS_MOLECULAR_SA = auto()
    # Energy (also collected in ENERGY_PLOT_TYPES)
    ENERGY_ALL_GROUPED = auto()
    ENERGY_ALL_INDIVIDUAL = auto()
    ENERGY_INDIVIDUAL_MOLECULES = auto()
    ENERGY_CUSTOM_SUBSTRUCTURE_SETS = auto()
    ENERGY_CUSTOM_ASL_SETS = auto()
# The energy-related members of TrajectoryPlotType.
ENERGY_PLOT_TYPES = {
    TrajectoryPlotType.ENERGY_ALL_GROUPED,
    TrajectoryPlotType.ENERGY_ALL_INDIVIDUAL,
    TrajectoryPlotType.ENERGY_INDIVIDUAL_MOLECULES,
    TrajectoryPlotType.ENERGY_CUSTOM_SUBSTRUCTURE_SETS,
    TrajectoryPlotType.ENERGY_CUSTOM_ASL_SETS,
}


class PlotDataType(Enum):
    """
    Category returned by a plot manager's getPlotType(); used to group
    data when exporting.
    """
    RMSF = auto()
    TRAJECTORY = auto()
    ENERGY = auto()


#############################
# Plot Formatting Functions
#############################
def handle_chart_legend(chart, is_multiseries_interactions):
    """
    Configure the legend of `chart` for the kind of plot it shows.

    Multi-series interaction plots get a legend at the bottom of the chart
    with tooltips enabled; every other plot has its legend hidden.

    :param chart: Chart containing legend
    :type chart: QtCharts.QChart
    :param is_multiseries_interactions: is this a multiseries interaction plot
    :type is_multiseries_interactions: bool
    """
    chart_legend = chart.legend()
    if not is_multiseries_interactions:
        chart_legend.hide()
        return
    chart_legend.setShowToolTips(True)
    chart_legend.setAlignment(Qt.AlignBottom)
def _generateAxisSpecifications(data, axis):
    """
    Fit the range and tick count of `axis` to the given data.

    Values are rounded to one decimal place. The axis minimum (maximum) is
    additionally padded by MIN_AXIS_SPAN when the axis' current minimum
    (maximum) lies below (above) the data. The tick count is one tick per
    MIN_AXIS_SPAN of the resulting span, capped at MAX_AXIS_TICKS and never
    fewer than 2.

    Empty data leaves the axis untouched.

    :param data: Data for series on axis
    :type data: list
    :param axis: Axis to set
    :type axis: QValueAxis
    """
    axis_values = {round(val, 1) for val in data}
    if not axis_values:
        # Nothing to fit; min()/max() below would raise on an empty set.
        return

    # set min
    axis_min = min(axis_values)
    if axis.min() < axis_min:
        axis_min -= MIN_AXIS_SPAN
    axis.setMin(axis_min)

    # set max
    axis_max = max(axis_values)
    if axis.max() > axis_max:
        axis_max += MIN_AXIS_SPAN
    axis.setMax(axis_max)

    # set ticks. QValueAxis.setTickCount() takes an int and ignores counts
    # below 2. The span/step ratio is a float that may carry rounding noise
    # (e.g. 2.9999), so round it rather than passing it through.
    num_steps = round((axis_max - axis_min + MIN_AXIS_SPAN) / MIN_AXIS_SPAN)
    axis.setTickCount(max(2, min(MAX_AXIS_TICKS, num_steps)))


#############################
# TRADITIONAL PLOTS
#############################
class AbstractTrajectoryChartView(QtCharts.QChartView):
    """
    Chart view shared by all trajectory plots.

    A left click emits plotClicked with the clicked point in chart value
    coordinates. A left click is a left-button press and release at the same
    position, so drags are ignored.
    """
    # Emitted with the clicked point, in chart value coordinates.
    plotClicked = QtCore.pyqtSignal(QtCore.QPointF)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Widget position of the most recent left-button press.
        self._mouse_press_pos = None
        self.setMouseTracking(True)

    def mousePressEvent(self, event):
        """
        Remember where the left button went down, then defer to QChartView.
        """
        pressed_left = event.button() == Qt.LeftButton
        if pressed_left:
            self._mouse_press_pos = event.pos()
        super().mousePressEvent(event)

    def mouseReleaseEvent(self, event):
        """
        Emit plotClicked for a left click (release at the press position),
        then defer to QChartView.
        """
        if event.button() == Qt.LeftButton:
            click_pos = event.pos()
            # A release anywhere else means the user dragged: not a click.
            if click_pos == self._mouse_press_pos:
                self.plotClicked.emit(self.chart().mapToValue(click_pos))
        super().mouseReleaseEvent(event)
class AbstractTrajectoryPlotManager(QtCore.QObject):
    """
    Base class for storing plot data for trajectory analysis data. Also holds
    a reference to the chart view.

    Note that we cannot simply use QtCharts.QLineSeries.clicked for this
    because it does not appear to trigger on OS X.

    :ivar deletePlot: Emitted when the user asks for this plot to be deleted.
    :type deletePlot: `QtCore.pyqtSignal()`

    :ivar displayAsl: Display the asl for the corresponding entry id
        Signal. args are (asl, entry_id)
    :type displayAsl: `QtCore.pyqtSignal(str, int)`

    :ivar displayFrameAndAsl: Change frame and show ASL for given entry id
        Signal args are (asl, entry_id, frame_number)
    :type displayFrameAndAsl: `QtCore.pyqtSignal(str, int, int)`

    :ivar newDataAvailable: Emitted when plot has finished generating data.
    :type newDataAvailable: `QtCore.pyqtSignal()`

    :ivar showWarning: Emitted when plot requests parent panel to show a
        warning message; the argument is the message text.
    :type showWarning: `QtCore.pyqtSignal(str)`
    """
    deletePlot = QtCore.pyqtSignal()
    displayAsl = QtCore.pyqtSignal(str, int)
    displayFrameAndAsl = QtCore.pyqtSignal(str, int, int)
    newDataAvailable = QtCore.pyqtSignal()
    showWarning = QtCore.pyqtSignal(str)

    # Panel class for advanced plots. None for simple plots shown inline;
    # see createPanel() and setupView().
    PANEL_CLASS = None

    def __init__(self, panel):
        """
        :param panel: Parent panel. It must provide `entry_traj`, the
            trajectory entry being analyzed.
        """
        super().__init__(panel)
        self.panel = panel
        self.entry_traj = panel.entry_traj
        self.cms_fpath, self.trj_dir = self.entry_traj.getCMSPathAndTrajDir()
        self.view = AbstractTrajectoryChartView(parent=panel)
        self.window = self.view.window()
        self.view.plotClicked.connect(self.onPlotClicked)
        self.settings_hash = None
        # Series (legend) name -> {x value: y value}; filled by _addSeries().
        self.series_map = {}
        # Traj Player widgets use 1-based indexing, so we do here as well.
        # Keys are frame times / 1000, the unit of the "Time (ns)" axis.
        trj = self.entry_traj.trajectory
        self.time_to_frame_map = {
            fr.time / 1000: idx for idx, fr in enumerate(trj, start=1)
        }
        self.collapsible_plot_widget = None

    def configureTask(self, task):
        """
        Configure the given task with inputs for the currently selected
        trajectory. The task's taskDone signal is connected to
        self._onTaskEnded().

        Subclasses extend this method and assign the analysis mode.

        :param task: Task to configure
        :type task: tasks.AbstractTask
        """
        entry_traj = self.entry_traj
        task.specifyTaskDir(tasks.TEMP_TASKDIR)
        # Specify the path to the trajectory loaded into main panel:
        task.input.cms_fname = self.cms_fpath
        # NOTE: These are only needed for blocking tasks (the CMS model is not
        # serializable). For blocking tasks, linking the model directly is
        # faster than re-loading it from disk.
        if isinstance(task, traj_plot_models.tasks.BlockingFunctionTask):
            task.input.msys_model = entry_traj.msys_model
            task.input.cms_model = entry_traj.cms_model
            task.input.trajectory = entry_traj.trajectory
        task.taskDone.connect(lambda: self._onTaskEnded(task))

    def setupView(self, fixed_height=250, multi_series=False):
        """
        Strip padding from the chart, configure its legend and, for simple
        (panel-less) plots, fix the view height.

        :param fixed_height: View height used when PANEL_CLASS is not set
        :type fixed_height: int
        :param multi_series: Whether this is a multi-series interaction plot
        :type multi_series: bool
        """
        chart = self.chart()
        # Remove as much unnecessary padding from the chart as possible:
        chart.layout().setContentsMargins(0, 0, 0, 0)
        chart.setWindowFrameMargins(0, 0, 0, 0)
        chart.setBackgroundRoundness(0)
        tplots.handle_chart_legend(chart, multi_series)
        self.view.setMouseTracking(True)
        if not self.PANEL_CLASS:
            self.view.setFixedHeight(fixed_height)

    def start(self):
        """
        Start the analysis task.
        """
        self.task.start()

    def createPanel(self):
        """
        For advanced plots, creates an instance of a plot panel and returns it.
        For simple (panel-less) plots, returns None.
        """
        if self.PANEL_CLASS:
            return self.PANEL_CLASS(self)
        return None

    def getPlotTitle(self):
        """
        :return: The initial title of this plot
        :rtype: str
        """
        plot_title, _ = self.getInitialPlotTitleAndTooltip()
        return plot_title

    def _validateTask(self, task):
        """
        Validates task completed correctly. Throws warnings on issues.

        :param task: Finished task to be processed
        :type task: `traj_plot_model.TrajectoryAnalysisTask`

        :return: Whether task is valid
        :rtype: bool
        """
        if task.status == task.FAILED:
            log_file = task.getTaskFilename(task.name + '.log')
            msg = f"Job failed. For more info, see:\n{log_file}"
            self.showWarning.emit(msg)
            return False
        return True

    def _onTaskEnded(self, task):
        """
        Create a plot from the data of a completed task. Used for single-series
        plots.

        A RuntimeError raised while loading results is reported through
        showWarning. On success, newDataAvailable is emitted.

        :param task: Finished task to be processed
        :type task: tasks.AbstractTask
        """
        if not self._validateTask(task):
            return
        try:
            self.loadFromTask(task)
        except RuntimeError as err:
            self.showWarning.emit(str(err))
        else:
            self.newDataAvailable.emit()

    def createCollapsiblePlotWidget(self):
        """
        Create the collapsible widget that hosts this plot's view, and keep a
        reference to it.

        :return: The new widget
        :rtype: CollapsiblePlot
        """
        plot_title, plot_tooltip = self.getInitialPlotTitleAndTooltip()
        plot_widget = CollapsiblePlot(
            parent=self.panel,
            system_title=self.entry_traj.cms_model.fsys_ct.title,
            plot_title=plot_title,
            plot=self,
            tooltip=plot_tooltip)
        self.collapsible_plot_widget = plot_widget
        return plot_widget

    def isRunning(self):
        """
        Return True if the plot is still generating data.
        """
        return self.task.status is self.task.RUNNING

    def chart(self):
        """
        :return: The chart shown in this plot's view
        :rtype: QtCharts.QChart
        """
        return self.view.chart()

    def onPlotClicked(self, value):
        """
        Jump to the frame nearest the clicked time and show the task's fit ASL.

        :param value: Clicked point in chart value coordinates
        :type value: QtCore.QPointF
        """
        # Subclasses assign self.task after this base __init__ has already
        # connected the click signal, so the attribute may not exist yet.
        task = getattr(self, 'task', None)
        if not task or not self.time_to_frame_map:
            return
        frame_idx = self._getNearestFrameForTime(value.x())
        if frame_idx is not None:
            eid = self.entry_traj.eid
            self.displayFrameAndAsl.emit(task.input.fit_asl, eid, frame_idx)

    def _getNearestFrameForTime(self, time):
        """
        Given a time value, return the frame nearest to that time.

        :param time: Time to get the nearest frame for
        :type time: float

        :return: 1-based frame index closest to the specified time or None
            if time is out of range (or there are no frames).
        :rtype: int or None
        """
        time_map = self.time_to_frame_map
        if not time_map:
            return None
        if time < min(time_map) or time > max(time_map):
            return None
        # On ties, min() keeps the first key, like a strict "<" scan would.
        nearest_key = min(time_map, key=lambda key: abs(time - key))
        return time_map[nearest_key]

    def getPlotType(self):
        """
        Returns what type of plot this class uses; used for grouping export
        data.
        """
        raise NotImplementedError

    def getDataForExport(self):
        """
        Return a list of row data to export to CSV or Excel. Subclasses must
        override.

        :return: Data to be exported
        :rtype: list(list)
        """
        raise NotImplementedError

    def getInitialPlotTitleAndTooltip(self):
        """
        Return the plot title and tooltip for this plot.

        :return: Plot title, Plot tooltip.
        :rtype: (str, str or None)
        """
        raise NotImplementedError

    def getSettingsHash(self):
        """
        :return: Identifier built from the task's analysis mode and
            additional arguments; see generateSettingsHash().
        :rtype: tuple
        """
        task = self.task
        return self.generateSettingsHash([
            task.input.analysis_mode.name, task.input.additional_args,
            task.input.additional_kwargs
        ])

    def generateSettingsHash(self, settings_list):
        """
        Return a tuple that uniquely identifies this plot.

        :param settings_list: List of settings that can uniquely identify the
            plot. In addition to these, the plot class name and trajectory path
            will be added.
        :type settings_list: list

        :return: Unique identifier for the plot.
        :rtype: tuple
        """
        return (self.__class__.__name__, self.cms_fpath, *settings_list)

    def getExportData(self):
        """
        Most panels export the same data whether export was selected from the
        plot panel or the main panel. Override this method to export a
        different type of data when exporting from the parent panel, via the
        "Export Results..." button.
        """
        return self.getDataForExport()

    def exportToCSV(self):
        """
        Export plot data to a CSV file
        """
        fpath = filedialog.get_save_file_name(
            parent=self.window,
            caption="Save as CSV",
            filter="Comma-separated value (*.csv)")
        if not fpath:
            return
        rows = self.getDataForExport()
        with csv_unicode.writer_open(fpath) as fh:
            csv.writer(fh).writerows(rows)

    def exportToExcel(self):
        """
        Export data to an Excel workbook (.xlsx).

        openpyxl only writes the Office Open XML (.xlsx) format, so the dialog
        offers that extension. Offering .xls would produce xlsx content under
        a legacy extension, which Excel warns about.
        """
        fpath = filedialog.get_save_file_name(parent=self.window,
                                              caption="Save as Excel Workbook",
                                              filter='Excel (*.xlsx)')
        if not fpath:
            return
        wb = openpyxl.Workbook()
        ws = wb.active
        for row in self.getDataForExport():
            ws.append(row)
        wb.save(fpath)

    def saveImage(self):
        """
        Save a .png file of the plot
        """
        fpath = filedialog.get_save_file_name(parent=self.window,
                                              caption="Save Image",
                                              filter="PNG (*.png)")
        if not fpath:
            return
        view = self.view
        aspect_ratio = view.height() / view.width()
        # make sure image has high enough resolution for publication use.
        pixmap = QtGui.QPixmap(IMAGE_WIDTH, int(IMAGE_WIDTH * aspect_ratio))
        pixmap.fill(Qt.transparent)
        painter = QtGui.QPainter(pixmap)
        try:
            view.render(painter)
        finally:
            # Finish painting before writing the pixmap out, so all rendering
            # has been flushed to the device and the painter is released.
            painter.end()
        pixmap.save(fpath)

    def showContextMenu(self):
        """
        Show the plot's context menu (show/hide, save image, export, delete)
        at the cursor position.
        """
        menu = QtWidgets.QMenu(self.window)
        if self.view.isVisibleTo(self.parent()):
            menu.addAction(HIDE, self.view.hide)
        else:
            menu.addAction(SHOW, self.view.show)
        menu.addSeparator()
        menu.addAction(SAVE_IMG, self.saveImage)
        menu.addAction(EXPORT_CSV, self.exportToCSV)
        menu.addAction(EXPORT_EXCEL, self.exportToExcel)
        menu.addSeparator()
        menu.addAction(DELETE, self.deletePlot)
        menu.exec(QtGui.QCursor.pos())

    def loadFromTask(self, task):
        """
        Load in results from the given task.

        :param task: Task to get result data from.
        :type task: tasks.AbstractTask
        """
        chart = self.chart()
        self._addSeries(task)
        # Explicitly save a reference to the series so it doesn't get destroyed (PANEL-18838)
        self._series = chart.series()
        self.formatPlotAxes()

    def formatPlotAxes(self):
        """
        Formats axes tick numbers and spacing depending on result data.

        The chart's default axes are recreated. The horizontal axis is titled
        "Time (ns)". The vertical axis is titled with the analysis mode's
        unit; its range is fixed to [-180, 180] for torsions and otherwise
        fitted to the result values.
        """
        analysis_mode = self.task.input.analysis_mode
        result = self.task.output.result
        chart = self.chart()
        chart.createDefaultAxes()
        for axis in chart.axes():
            if axis.orientation() == Qt.Orientation.Horizontal:
                axis.setTitleText('Time (ns)')
                continue
            unit_lbl = traj_plot_models.ANALYSIS_MODE_MAP[analysis_mode].unit
            axis.setTitleText(unit_lbl)
            if analysis_mode == AnalysisMode.Torsion:
                axis.applyNiceNumbers()
                axis.setLabelFormat('%i')
                interstitial_values = max(result) - min(result) + 1
                # setTickCount() requires an int; result values may be floats.
                num_ticks = int(min(MAX_AXIS_TICKS, interstitial_values))
                axis.setTickCount(num_ticks)
                axis.setMin(-180)
                axis.setMax(180)
            else:
                axis.setLabelFormat('%.1f')
                _generateAxisSpecifications(result, axis)

    def _addSeries(self, task):
        """
        Adds a given task's output as a series to the plot. Also hooks up plot
        widget for interaction.

        Torsion data is split into several line segments wherever the dihedral
        wraps around +/-180, so no near-vertical line is drawn. Every segment
        is drawn with the same pen and attached to the same axes.

        :param task: Task to use for output
        :type task: traj_plot_model.AnalysisTask
        """
        chart = self.chart()
        legend_name = task.output.legend_name
        series = tplots.OutputSeries()
        series.setName(legend_name)
        self.series_map[legend_name] = {}
        values_by_x = self.series_map[legend_name]
        mode = task.input.analysis_mode
        index_based = mode in traj_plot_models.INDEX_BASED_MODES

        # Create bottom (x) and left (y) axes
        x_axis = QtCharts.QValueAxis()
        chart.addAxis(x_axis, Qt.AlignBottom)
        y_axis = tplots.OutputAxis()
        chart.addAxis(y_axis, Qt.AlignLeft)

        # Create residue information
        if mode is AnalysisMode.ResRMSF:
            self._addResidueSeries(series, x_axis)

        # Only frame-based data needs the trajectory (for frame times).
        trj = None if index_based else self.entry_traj.trajectory
        result = task.output.result
        segments = [series]
        for idx, val in enumerate(result):
            # x is the atom/residue index, or the frame time
            x_val = idx if index_based else trj[idx].time / 1000
            if mode == AnalysisMode.Torsion and idx > 0:
                # Avoid near-vertical line when dihedral flips
                # from -179 to +179
                prev_val = result[idx - 1]
                if (val > 178 and prev_val < -178) or (val < -178 and
                                                       prev_val > 178):
                    chart.addSeries(series)
                    next_series = tplots.OutputSeries()
                    # Reuse the (now themed) pen; otherwise the chart theme
                    # would draw the continuation in a different color.
                    next_series.setPen(series.pen())
                    series = next_series
                    segments.append(series)
            series.append(x_val, val)
            values_by_x[x_val] = val
        chart.addSeries(series)
        for segment in segments:
            segment.attachAxis(x_axis)
            segment.attachAxis(y_axis)

    def _createAreaSeries(self, current_idx, max_idx, series_cls):
        """
        Creates an area series for a secondary structure

        The area spans y in [0, 1] and x in
        [current_idx - 0.5, current_idx + 0.5], clipped to [0, max_idx].

        :param current_idx: Index the area is centered on
        :param max_idx: Largest allowed x value
        :param series_cls: QAreaSeries subclass to instantiate
        :return: The new area series (not yet added to the chart)
        """
        chart = self.chart()
        btm_line = QtCharts.QLineSeries(chart)
        top_line = QtCharts.QLineSeries(chart)
        left_idx = max(0, current_idx - 0.5)
        right_idx = min(current_idx + 0.5, max_idx)
        btm_line.append(left_idx, 0)
        btm_line.append(right_idx, 0)
        top_line.append(left_idx, 1)
        top_line.append(right_idx, 1)
        return series_cls(btm_line, top_line)

    def _generateAtomLabels(self, anums):
        """
        Generates atom labels for the given atoms from the trajectory.

        Atoms with PDB atom and residue names are labeled
        "<res><resnum><inscode>:<atom>", using the one-letter residue code
        when known. If all atoms share one residue, only the first label
        carries the residue part.

        :param anums: list of atom ids
        :type anums: list(int)

        :return: List of atom labels for the given atoms.
        :rtype: list(str)
        """
        st = self.entry_traj.cms_model.fsys_ct
        resnums = {st.atom[anum].resnum for anum in anums}
        same_residue = len(resnums) == 1 and None not in resnums
        atom_lbls = []
        for idx, anum in enumerate(anums):
            atom = st.atom[anum]
            pdb_atom = atom.pdbname.strip()
            pdb_res = atom.pdbres.strip()
            pdb_resnum = atom.resnum
            ins_code = atom.inscode.strip()
            lbl = atom.name
            if pdb_atom and pdb_res:
                pdb_res = structure.RESIDUE_MAP_3_TO_1_LETTER.get(
                    pdb_res, pdb_res)
                if not same_residue or idx == 0:
                    lbl = f"{pdb_res}{pdb_resnum}{ins_code}:{pdb_atom}"
                else:
                    lbl = f"{pdb_atom}"
            elif pdb_atom:
                lbl = f"{pdb_atom}:{atom.index}"
            atom_lbls.append(lbl)
        return atom_lbls
class TrajectoryAnalysisPlotManager(AbstractTrajectoryPlotManager):
    """
    Chart class used for graphs with an x-axis of frames
    """

    def getPlotTitle(self):
        """
        Return the title shown in the collapsible plot widget (which the user
        may have edited), or the initial title if there is no widget yet.

        :rtype: str
        """
        if self.collapsible_plot_widget:
            return self.collapsible_plot_widget.getPlotTitle()
        return super().getPlotTitle()

    def getPlotType(self):
        """
        Returns what type of plot this class uses; used for grouping export
        data. For TRAJECTORY grouping, each plot's data is added as a column
        in the same Excel sheet.
        """
        return tplots.PlotDataType.TRAJECTORY

    def getDataForExport(self):
        """
        Return a list of row data to export to CSV or Excel.

        The header row is "Frame", "Time (ns)", then one column per series.
        A single series is titled with the plot title instead of its legend
        name. Frames for which a series has no value get an empty (None) cell.

        :return: Data to be exported
        :rtype: list(list)
        """
        series_keys = list(self.series_map)
        # If there is a single series title use plot widget title instead.
        if len(series_keys) == 1:
            series_titles = [self.getPlotTitle()]
        else:
            series_titles = series_keys
        rows = [["Frame", "Time (ns)", *series_titles]]
        for time, frame_idx in self.time_to_frame_map.items():
            row = [frame_idx, time]
            # Use .get() so a series that does not cover every frame leaves a
            # blank cell instead of aborting the export with a KeyError.
            row.extend(self.series_map[key].get(time) for key in series_keys)
            rows.append(row)
        return rows
class WorkspaceMeasurementPlotManager(TrajectoryAnalysisPlotManager):
    """
    Plot of a Workspace measurement (distance, angle, dihedral or planar
    angle) over the trajectory.
    """

    def __init__(self, panel, mode, measurement, centroid_asls):
        """
        :param panel: Parent panel
        :type panel: QtWidget.QWidget

        :param mode: Analysis mode
        :type mode: AnalysisMode

        :param measurement: Tuple defining the measurement from MaestroHub
        :type measurement: tuple

        :param centroid_asls: List of ASLs for atoms involved in measurement.
        :type centroid_asls: list(str)
        """
        super().__init__(panel)
        self.task = traj_plot_models.TrajectoryAnalysisTask()
        self.configureTask(self.task, mode, measurement, centroid_asls)
        self.setupView()

    def configureTask(self, task, mode, measurement, centroid_asls):
        """
        :param task: Task to configure.
        :type task: traj_plot_models.TrajectoryAnalysisTask

        :param mode: Analysis mode
        :type mode: AnalysisMode

        :param measurement: Tuple defining the measurement from MaestroHub
        :type measurement: tuple

        :param centroid_asls: List of ASLs for atoms involved in measurement.
        :type centroid_asls: list(str)
        """
        # The base class already assigns task.input.cms_fname.
        super().configureTask(task)
        task.input.analysis_mode = mode
        # Every element but the last is an atom spec whose text after the
        # final ':' is the atom number. The last element is skipped
        # (presumably the measured value - confirm against MaestroHub).
        alist = [int(atom.split(':')[-1]) for atom in measurement[:-1]]
        task.input.atom_labels = self._generateAtomLabels(alist)
        task.input.centroid_asl_list = centroid_asls
        task.input.fit_asl = ' OR '.join(centroid_asls)

    def getInitialPlotTitleAndTooltip(self):
        # See base method for documentation.
        # Raises ValueError for an analysis mode this plot does not support.
        task = self.task
        mode = task.input.analysis_mode
        atom_labels = task.input.atom_labels
        if mode == AnalysisMode.Distance:
            atom_label = ' to '.join(atom_labels)
            prefix = "Distance"
        elif mode == AnalysisMode.Angle:
            atom_label = ', '.join(atom_labels)
            prefix = "Angle"
        elif mode == AnalysisMode.Torsion:
            atom_label = ', '.join(atom_labels)
            prefix = "Dihedral"
        elif mode == AnalysisMode.PlanarAngle:
            atom_label = ', '.join(atom_labels)
            prefix = "Planar Angle"
        else:
            raise ValueError(f"Unknown plot mode: {mode}")

        title = f"{prefix} {self.plot_number} - {atom_label}"
        tooltip = None
        # TODO: consider showing more information in a tooltip instead of
        # returning None
        return title, tooltip

    def getSettingsHash(self):
        """
        :return: Identifier built from the analysis mode, atom labels and
            centroid ASLs; see generateSettingsHash().
        :rtype: tuple
        """
        task = self.task
        return self.generateSettingsHash([
            task.input.analysis_mode.name, task.input.atom_labels,
            task.input.centroid_asl_list
        ])
class CentroidDistancePlotManager(TrajectoryAnalysisPlotManager):
    """
    Plot of the distance (two sets) or angle (otherwise) between the
    centroids of groups of atoms.

    NOTE: Disabled as of PANEL-20518
    """

    def __init__(self, panel, atom_sets):
        """
        :param panel: Parent panel
        :param atom_sets: Atom IDs of each set, one tuple per set.
        :type atom_sets: list(tuple(int))
        """
        super().__init__(panel)
        self.task = traj_plot_models.TrajectoryAnalysisTask()
        self.configureTask(self.task, atom_sets)
        self.setupView()

    def configureTask(self, task, atom_sets):
        """
        Configure the plot task with one "atom.n ..." ASL per atom set.

        :param atom_sets: Atom IDs of each set, one tuple per set.
        :type atom_sets: list(tuple(int))
        """
        super().configureTask(task)
        asl_sets = []
        for aset in atom_sets:
            asl_sets.append('atom.n ' + ','.join(map(str, aset)))
        if len(asl_sets) == 2:
            task.input.analysis_mode = AnalysisMode.Distance
        else:
            task.input.analysis_mode = AnalysisMode.Angle
        task.input.centroid_asl_list = asl_sets
        task.input.fit_asl = ' OR '.join(asl_sets)

    def getInitialPlotTitleAndTooltip(self):
        # See base method for documentation.
        return f"Distance {self.plot_number}", None
class AslPlotManager(TrajectoryAnalysisPlotManager):
    """
    Plot of a per-frame descriptor (radius of gyration or a surface area)
    for the atoms matched by an ASL.
    """

    def __init__(self, panel, analysis_mode, asl, anums):
        """
        :param panel: Parent panel
        :param analysis_mode: Descriptor to compute
        :type analysis_mode: AnalysisMode
        :param asl: ASL selecting the atoms
        :type asl: str
        :param anums: Atom numbers matched by `asl`
        """
        super().__init__(panel)
        self.task = traj_plot_models.TrajectoryAnalysisSubprocTask()
        self.configureTask(self.task, analysis_mode, asl, anums)
        self.setupView()

    def configureTask(self, task, analysis_mode, asl, anums):
        """
        Configure `task` for `analysis_mode`, passing the ASL as the analysis
        argument and also using it as the fit ASL.
        """
        super().configureTask(task)
        inputs = task.input
        inputs.analysis_mode = analysis_mode
        inputs.additional_args = [asl]
        inputs.atom_numbers = list(anums)
        inputs.fit_asl = asl

    def getInitialPlotTitleAndTooltip(self):
        # See base method for documentation.
        # Raises ValueError for an analysis mode this plot does not support.
        inputs = self.task.input
        mode = inputs.analysis_mode
        prefix_by_mode = {
            AnalysisMode.Gyradius: "Radius of Gyration",
            AnalysisMode.PolarSurfaceArea: "Polar Surface Area",
            AnalysisMode.SolventAccessibleSurfaceArea:
                "Solvent Accessible Surface Area",
            AnalysisMode.MolecularSurfaceArea: "Molecular Surface Area",
        }
        prefix = prefix_by_mode.get(mode)
        if prefix is None:
            raise ValueError(f"Unknown plot mode: {mode}")
        num_atoms = len(inputs.atom_numbers)
        return (f"{prefix} {self.plot_number} ({num_atoms} atoms)",
                inputs.fit_asl)
class AbstractAdvancedTrajectoryPlotManager(AbstractTrajectoryPlotManager):
    """
    Plot manager for advanced plots (RMSF, Energy), which are reached through
    shortcut widgets at the bottom of the panel rather than shown inline.
    """

    def createShortcutWidget(self, plot_panel):
        """
        Build the shortcut widget that opens `plot_panel`, and title the chart.

        :param plot_panel: Plot panel that shortcut will launch
        :type plot_panel: BaseAdvancedPlotPanel

        :return: Shortcut for the Advanced Plot
        :rtype: AdvancedPlotShortcut
        """
        # enableSeriesTracking() is expected to be provided by subclasses.
        self.enableSeriesTracking()
        title = self.getInitialPlotTitleAndTooltip()[0]
        self.chart().setTitle(title)
        return AdvancedPlotShortcut(
            plot_panel,
            shortcut_title=f'{plot_panel.SHORTCUT_PREFIX} {self.plot_number}',
            window_title=title)
class Callout(QtWidgets.QGraphicsItem):
    """
    A callout is a rounded rectangle that displays values for a point on a
    QChart
    """

    def __init__(self, chart, series, pos, text_list):
        """
        :param chart: Chart the callout is drawn on
        :type chart: QtCharts.QChart
        :param series: Series `pos` belongs to; used to map `pos` to chart
            coordinates
        :param pos: Anchor point of the callout, in value coordinates
        :type pos: QtCore.QPointF
        :param text_list: Lines of text to display
        :type text_list: list(str)
        """
        super().__init__()
        self.font = QtGui.QFont()
        self.chart = chart
        self.series = series
        # NOTE(review): this instance attribute shadows QGraphicsItem.pos()
        # for Python callers; kept for backward compatibility.
        self.pos = pos
        self.text_list = text_list
        # Lazily computed by boundingRect() and cached afterwards.
        self.bounding_rect = None

    def paint(self, painter, option, widget):
        """
        Draw a light-blue rounded rectangle containing the text lines.
        """
        rect = self.boundingRect()
        light_blue = QtGui.QColor(220, 220, 255)
        painter.setBrush(light_blue)
        painter.drawRoundedRect(rect, 5, 5)
        # Inset the text by 5px on every side. This matches the 10px total
        # buffer reserved in generateBoundingRect(). Positive right/bottom
        # offsets would shift the text instead and let it overflow the box.
        text_rect = rect.adjusted(5, 5, -5, -5)
        text = '\n'.join(self.text_list)
        painter.drawText(text_rect, Qt.AlignLeft, text)

    def generateBoundingRect(self):
        """
        Creates a bounding rect based on text length/height and chart position

        The box is at least 30x30 pixels plus a 10 pixel buffer. By default it
        sits above and to the left of the anchor point, and flips to the other
        side when it would leave the chart on that side.
        """
        # Generate metrics for callout text
        fm = QtGui.QFontMetrics(self.font)
        buffer = 10
        widest = max((fm.width(text) for text in self.text_list), default=0)
        text_width = max(widest, 30) + buffer
        text_height = max(fm.height() * len(self.text_list), 30) + buffer

        # Create rectangle, flipping orientation if at risk of escaping chart
        pt = self.chart.mapToPosition(self.pos, self.series)
        x0, y0 = pt.x(), pt.y()
        x1 = x0 - text_width
        if x1 < 0:
            x1 = x0 + text_width
        y1 = y0 - text_height
        if y1 < 0:
            y1 = y0 + text_height
        pt0 = QtCore.QPointF(min(x0, x1), min(y0, y1))
        pt1 = QtCore.QPointF(max(x0, x1), max(y0, y1))
        return QtCore.QRectF(pt0, pt1)

    def boundingRect(self):
        """
        Return the cached bounding rect, computing it on first use.
        """
        if self.bounding_rect is None:
            self.bounding_rect = self.generateBoundingRect()
        return self.bounding_rect
class CollapsiblePlot(QtWidgets.QWidget):
    """
    This class defines a collapsible plot. The widget has an area for title
    text, a 'collapse' button and a 'close' button.
    """

    # Titles at least this long are truncated in the title field...
    _MAX_TITLE_LENGTH = 45
    # ...to this many characters, followed by '...'.
    _TRUNCATED_TITLE_LENGTH = 42

    def __init__(self,
                 parent=None,
                 system_title='',
                 plot_title='',
                 plot=None,
                 tooltip=None):
        """
        :param system_title: System title for the widget
        :type system_title: str

        :param plot_title: Title to set for this title bar
        :type plot_title: str

        :param plot: Plot to set in the collapsible area (required)
        :type plot: `AbstractTrajectoryPlotManager`

        :param tooltip: Optional tooltip for the title
        :type tooltip: str

        :raise ValueError: If no plot is given.
        """
        # Validate up front: the plot is used immediately below.
        if plot is None:
            raise ValueError("CollapsiblePlot requires a plot")
        super().__init__(parent=parent)
        self.plot = plot
        self.system_title = system_title
        self.fit_asl = None

        self.ui = collapsible_plot_ui.Ui_Form()
        self.ui.setupUi(self)
        self.ui.collapse_btn.clicked.connect(self.onCollapseButtonClicked)
        self.ui.close_btn.clicked.connect(plot.deletePlot)
        self.ui.system_title_label.setText(system_title)
        self._setTitleAndTooltip(plot_title, tooltip)

        plot.view.setSizePolicy(QtWidgets.QSizePolicy.Expanding,
                                QtWidgets.QSizePolicy.Expanding)
        self.ui.widget_layout.addWidget(plot.view)
        plot.setParent(self)
        plot.view.setVisible(True)

    def _setTitleAndTooltip(self, plot_title, tooltip):
        """
        Show `plot_title` in the title field, truncating long titles. A
        truncated title is moved (in full) into the tooltip, and a
        "Double-click to edit" hint is appended to any non-empty tooltip.

        :param plot_title: Title to display
        :type plot_title: str
        :param tooltip: Tooltip text, or None
        :type tooltip: str or None
        """
        if len(plot_title) >= self._MAX_TITLE_LENGTH:
            if tooltip is None:
                tooltip = plot_title
            else:
                tooltip = plot_title + ': ' + tooltip
            plot_title = plot_title[:self._TRUNCATED_TITLE_LENGTH] + '...'
        if tooltip:
            tooltip += '<br><i>Double-click to edit</i>'
        self.ui.plot_title_le.setText(plot_title)
        self.ui.plot_title_le.setToolTip(tooltip or '')

    def onPlotTitleChanged(self):
        """
        For multi-interaction plots, title can change as new tasks are
        completed; update the title (with the same truncation and tooltip
        handling as the initial title).
        """
        plot_title, tooltip = self.plot.getInitialPlotTitleAndTooltip()
        self._setTitleAndTooltip(plot_title, tooltip)

    def getPlotTitle(self):
        """
        Returns plot title.

        :return: plot title as currently shown in the title field
        :rtype: str
        """
        return self.ui.plot_title_le.text()

    def onCollapseButtonClicked(self):
        """
        Collapse or expand the plot depending on its current state.
        """
        view = self.plot.view
        already_visible = view.isVisibleTo(self.parent())
        view.setVisible(not already_visible)

    def mousePressEvent(self, event):
        """
        Show the plot's context menu on right click.
        """
        if event.button() == Qt.RightButton:
            self.plot.showContextMenu()
        super().mousePressEvent(event)
#############################
# ADVANCED PLOTS AND SHORTCUTS
#############################
class BaseAdvancedPlotPanel(basewidgets.Panel):
    """
    Base class for plot panels that get opened via shortcuts in the
    "Advanced Plots" section of the main plots panel.

    Subclasses must set `self.plot` (the plot manager shown in the panel).

    :cvar closeRequested: Signal emitted when the widget is closed.
    :type closeRequested: `QtCore.pyqtSignal()`
    """
    closeRequested = QtCore.pyqtSignal()

    def _showContextMenu(self):
        """
        Show the save/export/delete menu at the cursor position. "Delete"
        emits closeRequested.
        """
        plot = self.plot
        menu = QtWidgets.QMenu(self)
        export_actions = (
            (SAVE_IMG, plot.saveImage),
            (EXPORT_CSV, plot.exportToCSV),
            (EXPORT_EXCEL, plot.exportToExcel),
        )
        for label, slot in export_actions:
            menu.addAction(label, slot)
        menu.addSeparator()
        menu.addAction(DELETE, lambda: self.closeRequested.emit())
        menu.exec(QtGui.QCursor.pos())
class AdvancedPlotShortcut(basewidgets.BaseWidget):
    """
    Shortcut icon that opens an advanced plot panel (RMSF and Energy plots).

    Left click opens the panel; right click shows a View/Delete menu.
    """
    ui_module = shortcut_ui

    def __init__(self,
                 plot_panel,
                 shortcut_title='',
                 window_title='',
                 parent=None):
        """
        :param plot_panel: Panel opened by this shortcut; its `plot` attribute
            is the plot manager.
        :param shortcut_title: Text shown under the icon
        :param window_title: Window title for `plot_panel`
        :param parent: Parent widget
        """
        super().__init__(parent)
        self.plot_panel = plot_panel
        self.plot = plot_panel.plot
        plot_panel.setWindowTitle(window_title)
        plot_panel.closeRequested.connect(self.plot.deletePlot)
        self.ui.icon_lbl.setPixmap(
            QtGui.QPixmap(":/trajectory_gui_dir/icons/adv_plot.png"))
        self.shortcut_title = shortcut_title
        self.ui.shortcut_lbl.setText(self.shortcut_title)

    def mousePressEvent(self, event):
        """
        Left click shows the plot panel; right click shows the context menu.
        """
        super().mousePressEvent(event)
        button = event.button()
        if button == QtCore.Qt.LeftButton:
            self._raisePlotPanel()
        if button == QtCore.Qt.RightButton:
            self._showContextMenu()

    def _raisePlotPanel(self):
        """
        Show the plot panel and bring it to the front.
        """
        self.plot_panel.show()
        self.plot_panel.raise_()

    def _showContextMenu(self):
        """
        Show the View/Delete menu and run the chosen action.
        """
        menu = QtWidgets.QMenu(self)
        menu.addAction(VIEW_PLOT)
        menu.addSeparator()
        menu.addAction(DELETE)
        chosen = menu.exec(QtGui.QCursor.pos())
        if not chosen:
            return
        handlers = {
            VIEW_PLOT: self._raisePlotPanel,
            DELETE: lambda: self.plot.deletePlot.emit(),
        }
        handler = handlers.get(chosen.text())
        if handler is not None:
            handler()

    def getPlotTitle(self):
        """
        :return: The shortcut's title
        :rtype: str
        """
        return self.shortcut_title

    def close(self):
        """
        Close and remove this widget.
        """
        self.plot_panel.close()
        super().close()
class ShortcutRow(basewidgets.BaseWidget):
    """
    A horizontal row holding up to MAX_SHORTCUTS_IN_ROW advanced plot
    shortcuts.
    """

    def initLayOut(self):
        """
        Add the row layout, which ends with an expanding spacer.
        """
        super().initLayOut()
        self.row_layout = QtWidgets.QHBoxLayout()
        # New widgets are inserted before this trailing expanding spacer,
        # which keeps the shortcuts packed to the left.
        self.row_layout.addItem(
            QtWidgets.QSpacerItem(40, 20, QtWidgets.QSizePolicy.Expanding,
                                  QtWidgets.QSizePolicy.Minimum))
        self.main_layout.addLayout(self.row_layout)

    def hasSpace(self):
        """
        Returns whether the shortcut row has space for another widget
        """
        return MAX_SHORTCUTS_IN_ROW > self.widgetCount()

    def addWidget(self, wdg):
        """
        Append `wdg` after the existing shortcuts (before the spacer).
        """
        insert_at = self.widgetCount()
        self.row_layout.insertWidget(insert_at, wdg)

    def widgetCount(self):
        """
        :return: Number of shortcut widgets in the row (the spacer excluded)
        :rtype: int
        """
        item_count = self.row_layout.count()
        return item_count - 1
#############################
# Custom Series and Axes
#############################
class OutputAxis(QtCharts.QValueAxis):
    """
    Marker subclass of QValueAxis; used as the y axis for task output in
    AbstractTrajectoryPlotManager._addSeries().
    """
    pass
class BFactorAxis(QtCharts.QValueAxis):
    """
    Marker subclass of QValueAxis for B-factor data. Not used in this part of
    the file; presumably used by RMSF plots - confirm.
    """
    pass
class SecondaryStructureAxis(QtCharts.QValueAxis):
    """
    Marker subclass of QValueAxis for secondary-structure data. Not used in
    this part of the file; presumably used by RMSF plots - confirm.
    """
    pass
class OutputSeries(QtCharts.QLineSeries):
    """
    Marker subclass of QLineSeries; holds task output values in
    AbstractTrajectoryPlotManager._addSeries().
    """
    pass
class BFactorSeries(QtCharts.QLineSeries):
    """
    Marker subclass of QLineSeries for B-factor data. Not used in this part
    of the file; presumably used by RMSF plots - confirm.
    """
    pass
class SecondaryStructureHelixSeries(QtCharts.QAreaSeries):
    """
    Marker subclass of QAreaSeries for helix regions; presumably passed as
    `series_cls` to _createAreaSeries() - confirm.
    """
    pass
class SecondaryStructureStrandSeries(QtCharts.QAreaSeries):
    """
    Marker subclass of QAreaSeries for strand regions; presumably passed as
    `series_cls` to _createAreaSeries() - confirm.
    """
    pass
class EnergyPlotPanel(BaseAdvancedPlotPanel):
    """
    Plot for energy analysis.
    """
    ui_module = energy_plot_ui
    model_class = traj_plot_models.EnergyPlotModel
    SHORTCUT_PREFIX = 'Energy'

    def __init__(self, plot_view, parent=None):
        """
        :param plot_view: Plot manager whose view is embedded in this panel.
        :param parent: Parent widget
        """
        # Assigned before super().__init__(), presumably because
        # basewidgets.Panel.__init__ runs initSetUp()/initFinalize(), which
        # read self.plot.
        self.plot = plot_view
        self.chart = plot_view.chart()
        super().__init__(parent)
        self.setWindowTitle('Review Energy Plots')

    def initSetUp(self):
        """
        Embed the plot view and configure the sets table and options button.
        """
        super().initSetUp()
        self.ui.close_btn.clicked.connect(self.close)
        self.ui.plot_layout.addWidget(self.plot.view)
        hheader = self.ui.sets_table.view.horizontalHeader()
        hheader.setStretchLastSection(True)
        hheader.hide()
        pixmap = QtGui.QPixmap(icons.MORE_ACTIONS_DB)
        self.ui.options_btn.setIcon(QtGui.QIcon(pixmap))
        self.ui.options_btn.setIconSize(QtCore.QSize(30, 15))
        self.ui.options_btn.setStyleSheet('border: none;')
        self.ui.options_btn.clicked.connect(self._showContextMenu)
        self.resize(self.width(), 800)  # make taller

    def initFinalize(self):
        """
        Populate the sets table from the task's set names and select all sets.
        """
        super().initFinalize()
        # Populate the sets PLPTableWidget with sets from our model:
        sets = []
        for name in self.plot.task.input.set_names:
            row = traj_plot_models.SetRow()
            row.name = name
            sets.append(row)
        self.model.sets = sets
        spec = self.ui.sets_table.makeAutoSpec(self.model.sets)
        self.ui.sets_table.setSpec(spec)
        self.ui.sets_table.setPLP(self.model.sets)
        # By default select all sets:
        self.model.selected_sets = list(sets)

    def defineMappings(self):
        M = self.model_class
        ui = self.ui
        return [
            (ui.sets_table, M.sets),
            (ui.sets_table.selection_target, M.selected_sets),
            (ui.exclude_self_terms_cb, M.exclude_self_terms),
            (ui.coulomb_cb, M.coulomb),
            (ui.van_der_waals_cb, M.van_der_waals),
            (ui.bond_cb, M.bond),
            (ui.angle_cb, M.angle),
            (ui.dihedral_cb, M.dihedral),
        ] # yapf: disable

    def getSignalsAndSlots(self, model):
        return [
            (model.selected_setsChanged, self.updatePlotValues),
            (model.exclude_self_termsChanged, self.updatePlotValues),
            (model.coulombChanged, self.updatePlotValues),
            (model.van_der_waalsChanged, self.updatePlotValues),
            (model.bondChanged, self.updatePlotValues),
            (model.angleChanged, self.updatePlotValues),
            (model.dihedralChanged, self.updatePlotValues),
        ] # yapf: disable

    def _updateEnergyToggles(self):
        """
        Update check box states for different energy terms based on current
        UI state.

        "Exclude self terms" is disabled (and unchecked) when only one set is
        selected or when neither Coulomb nor van der Waals is checked. The
        bond/angle/dihedral boxes are disabled (and unchecked) while self
        terms are excluded.
        """
        ui = self.ui
        disable = len(self.model.selected_sets) == 1 or \
            (not ui.coulomb_cb.isChecked() and not ui.van_der_waals_cb.isChecked())
        self._enableEnergyToggle(ui.exclude_self_terms_cb, not disable)

        enable = not self.model.exclude_self_terms
        self_term_toggles = [ui.bond_cb, ui.angle_cb, ui.dihedral_cb]
        for cb in self_term_toggles:
            self._enableEnergyToggle(cb, enable)

    def _enableEnergyToggle(self, check_box, enable):
        """
        Enable or disable given check box depending on the 'enable' argument.
        A disabled check box is also unchecked.

        :param check_box: energy term check box
        :type check_box: QtWidgets.QCheckBox

        :param enable: defines whether check box should be enabled or not
        :type enable: bool
        """
        check_box.setEnabled(enable)
        if not enable:
            check_box.setChecked(False)

    @staticmethod
    def _formatEnergyTerms(terms_used, num_terms_total):
        """
        Describe the selected energy terms for the plot title.

        :param terms_used: Display names of the selected terms (non-empty)
        :type terms_used: list(str)
        :param num_terms_total: Number of available terms
        :type num_terms_total: int
        :return: e.g. "Total Energy", "Coulomb Energy",
            "Bond and Angle Energies" or "Bond, Angle and Dihedral Energies"
        :rtype: str
        """
        num_terms_used = len(terms_used)
        if num_terms_used == num_terms_total:
            return 'Total Energy'
        if num_terms_used == 1:
            return terms_used[0] + ' Energy'
        if num_terms_used == 2:
            return ' and '.join(terms_used) + ' Energies'
        return ', '.join(
            terms_used[:-1]) + ' and ' + terms_used[-1] + ' Energies'

    def updatePlotValues(self):
        """
        Slot for updating the chart based on current UI selection.

        Updates the energy check boxes, the title label (empty when no set or
        no term is selected) and the plotted data.
        """
        # Update state of energy check boxes.
        self._updateEnergyToggles()

        m = self.model
        term_name_map = {
            'Coulomb': m.coulomb,
            'van der Waals': m.van_der_waals,
            'Bond': m.bond,
            'Angle': m.angle,
            'Dihedral': m.dihedral,
        }
        terms_used = [name for name, param in term_name_map.items() if param]

        if not terms_used or not m.selected_sets:
            title = ''
        else:
            title = ' - '.join(setrow.name for setrow in m.selected_sets)
            if m.exclude_self_terms:
                title += ' Interactions'
            title += ': ' + self._formatEnergyTerms(terms_used,
                                                    len(term_name_map))
        self.ui.plot_title_lbl.setText(title)
        self.plot.setPlotData(self.getEnergyValues())

    def getEnergyValues(self):
        """
        Return the energy values based on the current panel settings.

        :return: Result of energy_plots.sum_results() over the selected sets
            (identified as "sel_NNN" by their index in the sets table) and
            checked energy terms, or None if no set or no term is selected.
        """
        m = self.model
        use_sets = [
            f'sel_{i:03}' for i, setrow in enumerate(m.sets)
            if setrow in m.selected_sets
        ]
        if not use_sets:
            return None

        checked_by_term = {
            'elec': m.coulomb,
            'vdw': m.van_der_waals,
            'stretch': m.bond,
            'angle': m.angle,
            'dihedral': m.dihedral,
        }
        use_terms = [
            name for name, checked in checked_by_term.items() if checked
        ]
        if not use_terms:
            return None

        include_self = not m.exclude_self_terms
        return energy_plots.sum_results(self.plot.results, use_sets, use_terms,
                                        include_self)
class EnergyPlotManager(AbstractAdvancedTrajectoryPlotManager):
    """
    Chart class for energy matrix data. The plot data will be populated by
    the EnergyPlotPanel.
    """
    PANEL_CLASS = EnergyPlotPanel

    def __init__(self, panel, cfg_file, sets):
        """
        Initialize the energy plot.

        :param panel: Parent panel
        :type panel: QtWidgets.QWidget.
        :param cfg_file: Path to the cfg file.
        :type cfg_file: str
        :param sets: Dict of sets where keys are set names and values are ASLs.
        :type sets: dict
        """
        super().__init__(panel)
        # Result data; filled in by loadFromTask() / setPlotData().
        self.results = None
        self.frame_times = None
        self.energies = None
        self.set_asls = None
        self.task = traj_plot_models.TrajectoryEnergyJobTask()
        self.configureTask(self.task, cfg_file, sets)
        self.setupView()

    def configureTask(self, task, cfg_file, sets):
        """
        Configure the energy task.

        :param task: Task to configure
        :type task: traj_plot_models.TrajectoryEnergyJobTask
        :param cfg_file: Path to the cfg file
        :type cfg_file: str
        :param sets: Dict of sets where keys are set names and values are ASLs.
        :type sets: dict
        """
        super().configureTask(task)
        task.input.analysis_mode = AnalysisMode.Energy
        # Run on the host chosen in the parent panel's job settings dialog.
        job_config = self.panel.job_settings_dlg.model
        host = job_config.host_settings.host
        task.job_config.host_settings.host = host
        task.name = 'desmond_energy'
        task.input.trj_dir = self.trj_dir
        task.input.cfg_fname = cfg_file
        task.input.set_names = list(sets.keys())
        task.input.set_asls = list(sets.values())
        # TODO: Consider setting fit_asl

    def getPlotType(self):
        """
        :return: The plot data type of this manager
        :rtype: PlotDataType
        """
        return PlotDataType.ENERGY

    def enableSeriesTracking(self):
        """
        No-op override: energy plots do not enable series tracking.
        """
        pass

    def loadFromTask(self, task):
        """
        Load in results from the given task.

        :param task: Task to get result data from.
        :type task: traj_plot_models.TrajectoryEnergyJobTask
        :raise RuntimeError: If the task's results file contains no values.
        """
        results, set_asls, frame_times = \
            energy_plots.parse_results_file(task.output.results_file)
        if results.val is None:
            # Task failed to complete.
            raise RuntimeError(
                'Energy analysis task failed to produce results.')
        self.set_asls = set_asls  # used for export
        self.results = results
        # Convert picoseconds to nanoseconds:
        self.frame_times = [time / 1000.0 for time in frame_times]
        # NOTE: Series will be added dynamically when EnergyPlotPanel is
        # initialized.

    def setPlotData(self, energies):
        """
        Set self.energies array to the given data, and re-draw the chart.

        :param energies: Energy value per frame, in the same order as
            self.frame_times, or None to clear the plot (no sets or terms
            selected).
        """
        self.energies = energies
        chart = self.chart()
        # Deletes the previous series together with its axis attachments.
        chart.removeAllSeries()
        if not chart.axes():
            # Create bottom (time) and left (energy) axes once.
            self.x_axis = QtCharts.QValueAxis()
            self.x_axis.setTitleText('Time (ns)')
            chart.addAxis(self.x_axis, Qt.AlignBottom)
            self.y_axis = tplots.OutputAxis()
            self.y_axis.setLabelFormat('%.0f')
            self.y_axis.setTitleText('Energy (kCal/mol)')
            chart.addAxis(self.y_axis, Qt.AlignLeft)

        if self.energies is None:
            # No sets or terms selected
            return

        # Add data series to the plot:
        series = tplots.OutputSeries()
        for x_time, y_energy in zip(self.frame_times, self.energies):
            series.append(x_time, y_energy)
        # QtCharts only attaches axes to a series that is already on the
        # chart, and every redraw creates a new series, so attach each time.
        chart.addSeries(series)
        series.attachAxis(self.x_axis)
        series.attachAxis(self.y_axis)

        _generateAxisSpecifications(self.energies, self.y_axis)
        self.x_axis.setMin(min(self.frame_times))
        self.x_axis.setMax(max(self.frame_times))

    def getDataForExport(self):
        """
        Return a list of row data to export to Excel or CSV. Used by the
        export menu in the plot sub-window.

        :return: Header row followed by one [frame, time, energy] row per
            frame (1-based frame numbers). Only the header row is returned
            when nothing is currently plotted.
        :rtype: list(list)
        """
        rows = [['Frame', 'Time (ns)', 'Energy (kCal/mol)']]
        if self.frame_times is None or self.energies is None:
            # No results loaded, or no sets/terms selected.
            return rows
        for idx, (time, energy) in enumerate(zip(self.frame_times,
                                                 self.energies),
                                             start=1):
            rows.append([idx, time, energy])
        return rows

    def getExportData(self):
        """
        Return a list of row data to export to Excel. Used by the
        "Export Results..." button of the parent plots panels.

        The rows are: one [set id, ASL] row per set, then a header row, then
        one row per frame with the frame number (1-based), the time and every
        energy column from ``energy_plots.format_results_by_frame``.

        :return: Data to be exported
        :rtype: list(list)
        """
        all_values_dict = energy_plots.format_results_by_frame(self.results)
        rows = []
        # Print ASLs for each set, above the header row:
        for i, asl in enumerate(self.set_asls):
            rows.append([f"sel_{i:03}", asl])
        # Header row:
        header = ["Frame", "Time (ns)"] + list(all_values_dict.keys())
        rows.append(header)
        # Energies, one row per frame:
        energy_lists = all_values_dict.values()
        for idx, time in enumerate(self.frame_times):
            energies_for_frame = [energies[idx] for energies in energy_lists]
            row = [idx + 1, time] + energies_for_frame
            rows.append(row)
        return rows

    def getInitialPlotTitleAndTooltip(self):
        """
        :return: Initial window title ("Review Energy Plots - Energy <n>")
            and a tooltip listing the set names.
        :rtype: tuple(str, str)
        """
        task = self.task
        set_names = ', '.join(task.input.set_names)
        tooltip = f'Sets: {set_names}'
        prefix = "Review Energy Plots - Energy"
        return f"{prefix} {self.plot_number}", tooltip

    def getSettingsHash(self):
        """
        :return: Hash of the task's set names and set ASLs
        """
        task = self.task
        return self.generateSettingsHash(
            [task.input.set_names, task.input.set_asls])
class PlanarAngleAnalysisPlot(TrajectoryAnalysisPlotManager):

    def __init__(self, panel, alist):
        """
        Create the analysis task for the given atoms and set up the view.

        :param panel: Parent panel
        :param alist: Atoms defining the planar angle (presumably atom
            numbers -- they are stored as ``atom_numbers``).
        :type alist: list
        """
        super().__init__(panel)
        self.task = traj_plot_models.TrajectoryAnalysisTask()
        self.configureTask(self.task, alist)
        self.setupView()

    def configureTask(self, task, alist):
        """
        Configure ``task`` for a planar-angle analysis of ``alist``.

        :param task: Task to configure
        :type task: traj_plot_models.TrajectoryAnalysisTask
        :param alist: Atoms defining the planar angle
        :type alist: list
        """
        super().configureTask(task)
        task_input = task.input
        task_input.analysis_mode = AnalysisMode.PlanarAngle
        task_input.additional_args = alist
        task_input.atom_numbers = alist
        task_input.atom_labels = self._generateAtomLabels(alist)
        # Restrict fit_asl to exactly the atoms that define the angle.
        task_input.fit_asl = 'atom.n ' + ','.join(str(anum) for anum in alist)