Source code for schrodinger.trajectory.trajectory_gui_dir.rmsd_plots

"""
File containing RMSD plot code used in the Trajectory Plots GUI
"""
import os

from schrodinger.models import mappers
from schrodinger.Qt import QtGui, QtCore
from schrodinger.Qt.QtCore import Qt
from schrodinger.infra import mm
from schrodinger.structutils import analyze
from schrodinger import structure
from schrodinger.ui.qt import atomselector
from schrodinger.ui.qt import basewidgets
from schrodinger.ui.qt import filedialog

from . import advanced_plot_ui
from . import rmsd_settings_dialog_ui
from . import traj_plot_models
from . import plots as tplots

from schrodinger import get_maestro

maestro = get_maestro()

try:
    from schrodinger.application.desmond.packages import analysis
    from schrodinger.application.desmond.packages import topo
except ImportError:
    analysis = None
    topo = None

# User-facing warning text for a duplicate energy plot selection
ENERGY_PLOT_EXISTS_WARNING = 'Energy plot with this selection already exists.'

# Colors
# Line color of the RMSF series (and its axis) on RMSF plots
RMSF_COLOR = QtGui.QColor.fromRgb(158, 31, 222)
# Line color of the B factor series (and its axis) on residue RMSF plots
TEMP_COLOR = QtGui.QColor.fromRgb(207, 105, 31)
# Fill/border colors of the secondary-structure shading series
HELIX_COLOR = QtGui.QColor.fromRgb(253, 236, 232)
STRAND_COLOR = QtGui.QColor.fromRgb(229, 246, 250)

# Name given to the B factor series on residue RMSF plots
B_FACTOR_SERIES = 'b_factor_series'

# Minimum fraction of a residue's secondary-structure entries that must be
# helix (or strand) for that residue to be shaded as helix (or strand)
SECONDARY_STRUCTURE_THRESHOLD = 0.7
# Pen width applied to the RMSF and B factor series
SERIES_WIDTH = 1


def _is_series_ss(series):
    """
    Check whether a chart series is one of the secondary-structure shading
    series (helix or strand).

    Only exact type matches count; subclasses are not treated as matches.

    :param series: Series to check
    :type series: QtCharts.QAbstractSeries
    :return: True for helix/strand secondary-structure series
    :rtype: bool
    """
    ss_series_types = (
        tplots.SecondaryStructureStrandSeries,
        tplots.SecondaryStructureHelixSeries,
    )
    series_type = type(series)
    return series_type in ss_series_types


class RmsfPlotPanel(tplots.BaseAdvancedPlotPanel):
    """
    Panel that embeds an RMSF plot view along with its display options.

    The residue-specific widgets (B factor and secondary-structure toggles)
    are only shown when the plot was computed in residue RMSF mode.
    """
    ui_module = advanced_plot_ui
    model_class = traj_plot_models.RmsfPlotModel
    SHORTCUT_PREFIX = 'RMSF'

    def __init__(self, plot, parent=None):
        """
        :param plot: Plot manager whose chart and view this panel shows
            (presumably a BaseRmsfPlotManager subclass - TODO confirm).
        :param parent: Parent widget
        :type parent: QtWidgets.QWidget or None
        """
        # Set these before the base initializer runs: initSetUp() reads them
        # and is presumably invoked from the base __init__.
        self.plot = plot
        self.chart = plot.chart()
        self.mode = plot.task.input.analysis_mode
        super().__init__(parent)

    def _isResidueMode(self):
        """
        :return: Whether the plot shows per-residue RMSF data
        :rtype: bool
        """
        return self.mode is traj_plot_models.AnalysisMode.ResRMSF

    def initSetUp(self):
        super().initSetUp()
        show_residue_widgets = self._isResidueMode()
        for residue_wdg in (self.ui.residue_info_wdg,
                            self.ui.residue_options_wdg):
            residue_wdg.setVisible(show_residue_widgets)
        self.ui.plot_layout.addWidget(self.plot.view)
        self.ui.options_link.clicked.connect(self._onOptionsToggle)
        self.ui.close_btn.clicked.connect(self.close)

    def defineMappings(self):
        model_cls = self.model_class
        b_factor_target = mappers.TargetSpec(self.ui.pdb_b_factor_cb,
                                             slot=self._onBFactorToggle)
        ss_target = mappers.TargetSpec(self.ui.secondary_st_color_cb,
                                       slot=self._onSecondaryStructureToggle)
        return [
            (ss_target, model_cls.secondary_structure_colors),
            (b_factor_target, model_cls.b_factor_plot),
        ] # yapf: disable

    def _onBFactorToggle(self):
        """
        Show or hide the B factor series and axis to match the model
        (residue mode only).
        """
        show = self.model.b_factor_plot
        if not self._isResidueMode():
            return
        for chart_series in self.chart.series():
            if type(chart_series) == tplots.BFactorSeries:
                chart_series.setVisible(show)
        for chart_axis in self.chart.axes():
            if type(chart_axis) == tplots.BFactorAxis:
                chart_axis.setVisible(show)

    def _onSecondaryStructureToggle(self):
        """
        Show or hide the secondary-structure shading series to match the
        model (residue mode only).
        """
        show = self.model.secondary_structure_colors
        if not self._isResidueMode():
            return
        for chart_series in self.chart.series():
            if _is_series_ss(chart_series):
                chart_series.setVisible(show)

    def _onOptionsToggle(self):
        """Flip the visibility of the residue options widget."""
        options_wdg = self.ui.residue_options_wdg
        options_wdg.setVisible(not options_wdg.isVisible())

    def mousePressEvent(self, event):
        """Open the context menu on right click, then defer to the base."""
        is_right_click = event.button() == Qt.RightButton
        if is_right_click:
            self._showContextMenu()
        super().mousePressEvent(event)
class BaseRmsfPlotManager(tplots.AbstractAdvancedTrajectoryPlotManager):
    """
    Class containing RMSD plot related methods.

    Shared base for the RMSD/RMSF plot managers in this module; subclasses
    must set ANALYSIS_MODE.
    """
    ANALYSIS_MODE = None  # must be defined by subclasses

    def __init__(self, panel, cms_model, aids, fit_aids, fit_ref_pos):
        """
        :param panel: Panel passed through to the base plot manager.
        :param cms_model: CMS model with current frame selected.
        :type cms_model: cms.Cms
        :param aids: AIDs for the selected Workspace atoms.
        :type aids: list(int)
        :param fit_aids: AIDs for the reference atoms; may be None.
        :type fit_aids: list(int) or None
        :param fit_ref_pos: Coordinates for the reference atoms; may be None.
        :type fit_ref_pos: list((float, float, float)) or None
        """
        super().__init__(panel)
        mode = self.ANALYSIS_MODE
        self.task = traj_plot_models.TrajectoryAnalysisSubprocTask()
        self.configureTask(self.task, mode, cms_model, aids, fit_aids,
                           fit_ref_pos)
        self.setupView()
        self.initializeCallouts()
        # Hover callout currently shown, if any
        self._callout = None
        # Reference to the chart's series list; see enableSeriesTracking()
        self._series = None

    def getSettingsHash(self):
        """
        :return: Hash of the analysis mode and the task's positional and
            keyword arguments.
        """
        task = self.task
        return self.generateSettingsHash([
            task.input.analysis_mode.name, task.input.additional_args,
            task.input.additional_kwargs
        ])

    def _generateRMSFAtomLabels(self, anums):
        """
        Generates atom labels for RMSF graphs based on the trajectory CT.

        Each label is the atom name; when the atom has both a PDB atom name
        and a PDB residue name, the chain, residue (one-letter code when one
        is known) and residue number are appended in parentheses.

        :param anums: list of atom ids
        :type anums: list(int)
        :return: One label per entry of anums, in the same order
        :rtype: list(str)
        """
        st = self.entry_traj.cms_model.fsys_ct
        atom_lbls = []
        for anum in anums:
            atom = st.atom[anum]
            chain = atom.chain.strip()
            pdb_atom = atom.pdbname.strip()
            pdb_res = atom.pdbres.strip()
            pdb_resnum = atom.resnum
            lbl = atom.name
            if pdb_atom and pdb_res:
                pdb_res = structure.RESIDUE_MAP_3_TO_1_LETTER.get(
                    pdb_res, pdb_res)
                lbl = f"{atom.name} ({chain}: {pdb_res} {pdb_resnum})"
            atom_lbls.append(lbl)
        return atom_lbls

    def configureTask(self, task, mode, cms_model, aids, fit_aids,
                      fit_ref_pos):
        """
        Configure the given task based on specified settings.

        :param task: Task to configure.
        :type task: TrajectoryAnalysisSubprocTask
        :param mode: Analysis mode
        :type mode: traj_plot_models.AnalysisMode
        :param cms_model: CMS model with current frame selected.
        :type cms_model: cms.Cms
        :param aids: AIDs for the selected Workspace atoms.
        :type aids: list(int)
        :param fit_aids: AIDS for the reference atoms; may be None.
        :type fit_aids: list(int) or None
        :param fit_ref_pos: Coordinates for the reference atoms; may be None.
        :type fit_ref_pos: list((float, float, float)) or None
        """
        # Calculate arguments for RMSD task. Structure is from current frame
        super().configureTask(task)
        ref_pos = [cms_model.atom[aid].xyz for aid in aids]
        traj_fit_asl = "atom.n " + ",".join(map(str, aids))
        if mode in traj_plot_models.INDEX_BASED_MODES:
            # Index-based modes take the fit data positionally
            additional_args = [aids, fit_aids, fit_ref_pos]
            additional_kwargs = {}
        else:
            additional_args = [aids, ref_pos]
            additional_kwargs = {
                'fit_aids': fit_aids,
                'fit_ref_pos': fit_ref_pos
            }
        task.input.analysis_mode = mode
        task.input.additional_args = additional_args
        task.input.atom_labels = self._generateRMSFAtomLabels(aids)
        task.input.atom_numbers = list(aids)
        task.input.fit_asl = traj_fit_asl
        if fit_aids:
            task.input.additional_kwargs = additional_kwargs
        if mode is traj_plot_models.AnalysisMode.ResRMSF:
            # Per-residue lookups keyed by residue label, used for the
            # B factor series and for click-to-select on residue plots
            b_factors = {}
            residue_atoms = {}
            for aid in aids:
                atom = cms_model.atom[aid]
                res = atom.getResidue()
                res_lbl = analysis._prot_atom_label(cms_model, aid, True)
                b_factors[res_lbl] = res.temperature_factor
                residue_atoms[res_lbl] = res.getAtomIndices()
            task.input.b_factors = b_factors
            task.input.residue_atoms = residue_atoms

    def initializeCallouts(self):
        """
        Initializes the plot to accept events
        """
        chart = self.chart()
        chart.setAcceptHoverEvents(True)
        self.view.scene().addItem(chart)

    def enableSeriesTracking(self):
        """
        Connect the hover signal of every series except the
        secondary-structure shading series to onHover().
        """
        chart = self.chart()
        for line in chart.series():
            # Check each individual series (previously the whole series list
            # was passed, so shading series were never excluded).
            if not _is_series_ss(line):
                line.hovered.connect(self.onHover)
        # Explicitly save a reference to the series so it doesn't get
        # destroyed (PANEL-18838)
        self._series = self.chart().series()

    def generateCalloutText(self, pos):
        """
        Build the lines of text shown in the hover callout.

        :param pos: Hovered position on the series.
        :type pos: QtCore.QPointF
        :return: Callout lines; empty for modes other than atom or residue
            RMSF.
        :rtype: list(str)
        """
        rmsf_info = temp_info = ''
        callout_text_list = []
        data_x = round(pos.x())
        for series in self.chart().series():
            if _is_series_ss(series):
                continue
            series_type = type(series)
            # NOTE(review): at() indexes by point position. B factor points
            # are only appended for residues with a known B factor, so this
            # assumes none are missing - TODO confirm.
            data_point = series.at(data_x)
            if series_type == tplots.OutputSeries:
                rmsf_info = f'RMSF = {data_point.y():.2f} Å'
            if series_type == tplots.BFactorSeries and series.isVisible():
                temp_info = f'B Factor = {round(data_point.y(), 1)}'
        mode = self.task.input.analysis_mode
        if mode == traj_plot_models.AnalysisMode.AtomRMSF:
            atom_info = self.task.input.atom_labels[data_x]
            callout_text_list = [atom_info, rmsf_info]
        elif mode == traj_plot_models.AnalysisMode.ResRMSF:
            res_info = self.task.output.residue_info.residue_names[data_x]
            callout_text_list = [res_info, rmsf_info]
        if temp_info:
            callout_text_list.insert(1, temp_info)
        return callout_text_list

    def onHover(self, pos, enter):
        """
        Show a callout when the cursor enters a series and remove it when it
        leaves.

        :param pos: Hovered position on the series.
        :type pos: QtCore.QPointF
        :param enter: True when the cursor entered the series, False when it
            left.
        :type enter: bool
        """
        scene = self.view.scene()
        if enter:
            if self._callout is None:
                series = self.sender()
                text_list = self.generateCalloutText(pos)
                callout = tplots.Callout(self.chart(), series, pos, text_list)
                callout.setZValue(1)
                self._callout = callout
                scene.addItem(callout)
        elif self._callout is not None:
            # Only remove a callout that was actually added; previously a
            # throw-away callout was built and "removed" on every leave event
            # that had no callout showing.
            scene.removeItem(self._callout)
            self._callout = None

    def getPlotType(self):
        return tplots.PlotDataType.RMSF

    def getInitialPlotTitleAndTooltip(self):
        # See base method for documentation.
        task = self.task
        num_atoms = len(task.input.atom_numbers)
        title = f"RMSD {self.plot_number} ({num_atoms} atoms)"
        tooltip = task.input.fit_asl
        return title, tooltip
class RmsdByFramePlotManager(BaseRmsfPlotManager,
                             tplots.TrajectoryAnalysisPlotManager):
    """
    RMSD calculation where x-axis shows the frames. Inherits from
    TrajectoryAnalysisPlotManager for the getDataForExport() method
    """
    # The docstring above must come first in the class body to become
    # __doc__; it previously followed these attributes and was ignored.
    ANALYSIS_MODE = traj_plot_models.AnalysisMode.RMSD
    PANEL_CLASS = None  # This plot has no special settings

    def formatPlotAxes(self):
        """
        Formats axes tick numbers and spacing depending on result data.

        The horizontal axis is titled as time; every other axis gets a
        one-decimal label format, ticks derived from the result, a minimum of
        0 and an RMSD title.
        """
        result = self.task.output.result
        chart = self.chart()
        chart.createDefaultAxes()
        for axis in chart.axes():
            if axis.orientation() == Qt.Orientation.Horizontal:
                axis.setTitleText('Time (ns)')
            else:
                axis.setLabelFormat('%.1f')
                tplots._generateAxisSpecifications(result, axis)
                axis.setMin(0)
                axis.setTitleText('RMSD (Å)')
class AtomRmsfPlotManager(BaseRmsfPlotManager):
    """
    RMSD calculation where x-axis of the plot has atom numbers.
    """
    PANEL_CLASS = RmsfPlotPanel
    ANALYSIS_MODE = traj_plot_models.AnalysisMode.AtomRMSF

    def getDataForExport(self):
        """
        Return a list of row data to export to CSV or Excel.

        The first row is a header ('Atom Index' followed by each series
        title). Each following row holds a key and that position's value
        from every series. Rows are aligned by position, so all series are
        assumed to share keys in the same order.

        :return: Data to be exported
        :rtype: list(list)
        """
        header_row = ['Atom Index']
        series_titles = self.series_map.keys()
        header_row.extend(series_titles)
        rows = [header_row]
        for series in series_titles:
            for idx, (key, value) in enumerate(self.series_map[series].items()):
                # The first series to reach a position creates its row
                if idx >= len(rows) - 1:
                    rows.append([key])
                rows[idx + 1].append(value)
        return rows

    def getInitialPlotTitleAndTooltip(self):
        # See base method for documentation.
        task = self.task
        num_atoms = len(task.input.atom_numbers)
        title = f"Atom RMSF {self.plot_number} ({num_atoms} atoms)"
        tooltip = task.input.fit_asl
        return title, tooltip

    def formatPlotAxes(self):
        """
        Formats axes tick numbers and spacing depending on result data.
        """
        result = self.task.output.result
        chart = self.chart()
        chart.createDefaultAxes()
        for axis in chart.axes():
            if axis.orientation() == Qt.Orientation.Horizontal:
                axis.setTitleText('Atom Index')
                axis.setLabelFormat('%i')
            else:
                axis.setLabelFormat('%.1f')
                tplots._generateAxisSpecifications(result, axis)
                axis.setMin(0)
                axis.setTitleText('RMSF (Å)')

    def onPlotClicked(self, value):
        """
        Fire a signal to show ASL of selection on left click

        :param value: Clicked position in chart value coordinates.
        :type value: QtCore.QPointF
        """
        # Over-ride base method, because x-axis represents atoms
        # and not trajectory frames.
        data_x = round(value.x())
        atom_numbers = self.task.input.atom_numbers
        # Ignore clicks outside the plotted range; a negative index would
        # otherwise silently select an atom from the end of the list.
        if not 0 <= data_x < len(atom_numbers):
            return
        asl = f'atom.n {atom_numbers[data_x]}'
        self.displayAsl.emit(asl, self.entry_traj.eid)
class ResidueRmsfPlotManager(BaseRmsfPlotManager):
    """
    RMSD calculation where x-axis of the plot has residue numbers.
    """
    PANEL_CLASS = RmsfPlotPanel
    ANALYSIS_MODE = traj_plot_models.AnalysisMode.ResRMSF

    def _addResidueSeries(self, series, x_axis):
        """
        Adds residue series information (B factor and SSA info) to the plot

        :param series: RMSF series already on the chart; it is recolored.
        :type series: QtCharts.QXYSeries
        :param x_axis: X Axis to affix series to
        :type x_axis: QtCharts.QValueAxis
        """
        series.setColor(RMSF_COLOR)
        task = self.task
        chart = self.chart()
        res_names = task.output.residue_info.residue_names

        # Secondary Structure: shade residues whose entries are mostly
        # helix or strand, on a hidden 0-1 axis
        ss_axis = tplots.SecondaryStructureAxis()
        chart.addAxis(ss_axis, Qt.AlignRight)
        ss_axis.setMin(0)
        ss_axis.setMax(1)
        ss_axis.hide()
        for idx, sec_st in enumerate(
                task.output.residue_info.secondary_structures):
            if not sec_st:
                # With no data the threshold is 0 and "0 >= 0" would wrongly
                # shade the residue as helix.
                continue
            threshold = SECONDARY_STRUCTURE_THRESHOLD * len(sec_st)
            helix_count = sec_st.count(mm.MMCT_SS_HELIX)
            strand_count = sec_st.count(mm.MMCT_SS_STRAND)
            if helix_count >= threshold:
                area_series = self._createAreaSeries(
                    idx, len(res_names), tplots.SecondaryStructureHelixSeries)
            elif strand_count >= threshold:
                area_series = self._createAreaSeries(
                    idx, len(res_names), tplots.SecondaryStructureStrandSeries)
            else:
                continue
            chart.addSeries(area_series)
            area_series.attachAxis(ss_axis)
            area_series.attachAxis(x_axis)

        # B Factor: one point per residue with a known B factor, x = index
        y_axis = tplots.BFactorAxis()
        chart.addAxis(y_axis, Qt.AlignRight)
        b_factor_series = tplots.BFactorSeries()
        b_factor_series.setName(B_FACTOR_SERIES)
        for idx, res_name in enumerate(res_names):
            b_factor = task.input.b_factors.get(res_name)
            if b_factor is not None:
                b_factor_series.append(idx, b_factor)
        chart.addSeries(b_factor_series)
        b_factor_series.attachAxis(x_axis)
        b_factor_series.attachAxis(y_axis)

    def getDataForExport(self):
        """
        Return a list of row data to export to CSV or Excel.

        :return: Data to be exported: a header row, then one
            [index, residue name, value] row per residue.
        :rtype: list(list)
        """
        res_names = self.task.output.residue_info.residue_names
        # Residue plots always have a single series
        assert len(self.series_map) == 1
        plot_title, values_dict = next(iter(self.series_map.items()))
        header_row = ['Residue Index', 'Residue', plot_title]
        rows = [header_row]
        for (key, value), res_name in zip(values_dict.items(), res_names):
            rows.append([key, res_name, value])
        return rows

    def formatPlotAxes(self):
        """
        Formats the axes and colors series on a residue plot.
        """
        task = self.task
        chart = self.chart()
        axes_info = {
            tplots.BFactorAxis: (tplots.BFactorSeries, 'B Factor', TEMP_COLOR),
            tplots.OutputAxis: (tplots.OutputSeries,
                                traj_plot_models.ANGSTROMS_RMSF, RMSF_COLOR),
        }
        for axis in chart.axes():
            axis.setGridLineVisible(False)
            if axis.orientation() == Qt.Orientation.Horizontal:
                axis.setLabelFormat('%i')
                axis.setMin(0)
                axis.setMax(len(task.output.result) - 1)
                axis.setTitleText('Residue Index')
            elif type(axis) in axes_info:
                series_type, title, color = axes_info[type(axis)]
                unit_lbl = (f'<span style="color: {color.name()};">'
                            f'{title}</span>')
                # Look up the series for this axis explicitly so a missing
                # series can no longer leave a stale/unbound variable behind.
                matching = [
                    chart_series for chart_series in chart.series()
                    if type(chart_series) == series_type
                ]
                if matching:
                    series = matching[-1]
                    pen = series.pen()
                    pen.setWidth(SERIES_WIDTH)
                    series.setPen(pen)
                    y_values = [pt.y() for pt in series.pointsVector()]
                    series.setColor(color)
                    axis.setLinePenColor(color)
                    tplots._generateAxisSpecifications(y_values, axis)
                    axis.setLabelFormat('%.1f')
                axis.setTitleText(unit_lbl)
            else:
                # Otherwise, we have the (hidden) secondary structure axis:
                # color the shading series. No title is set here; previously
                # a stale or unbound label from another axis was used.
                for series in chart.series():
                    series_type = type(series)
                    if series_type == tplots.SecondaryStructureHelixSeries:
                        color = HELIX_COLOR
                    elif series_type == tplots.SecondaryStructureStrandSeries:
                        color = STRAND_COLOR
                    else:
                        continue
                    series.setColor(color)
                    series.setBorderColor(color)

    def getInitialPlotTitleAndTooltip(self):
        # See base method for documentation.
        task = self.task
        num_res = len(task.output.residue_info.residue_names)
        title = f"Protein RMSF {self.plot_number} ({num_res} residues)"
        tooltip = task.input.fit_asl
        return title, tooltip

    def onPlotClicked(self, value):
        """
        Fire a signal to show ASL of selection on left click

        :param value: Clicked position in chart value coordinates.
        :type value: QtCore.QPointF
        """
        # Over-ride base method, because x-axis represents residues
        # and not trajectory frames.
        data_x = round(value.x())
        res_names = self.task.output.residue_info.residue_names
        # Ignore clicks outside the plotted range (a negative index would
        # silently pick a residue from the end of the list).
        if not 0 <= data_x < len(res_names):
            return
        atoms = self.task.input.residue_atoms.get(res_names[data_x])
        if not atoms:
            return
        asl = f"atom.n {','.join(map(str, atoms))}"
        self.displayAsl.emit(asl, self.entry_traj.eid)
class RMSDDialog(basewidgets.BaseWidget):
    """
    RMSD/RMSF Dialog for specifying reference frame or structure for the
    trajectory plot gui.
    """
    ui_module = rmsd_settings_dialog_ui
    trajectoryChanged = QtCore.pyqtSignal(int)

    def initSetOptions(self):
        super().initSetOptions()
        self.std_btn_specs = {
            self.StdBtn.Ok: None,
            self.StdBtn.Cancel: None,
            self.StdBtn.Reset: self.reset
        }

    def initSetUp(self):
        super().initSetUp()
        self._structure = None
        self.setWindowTitle('RMSD / RMSF Settings')
        self._updateReferenceComponents()
        self.ui.frame_rb.clicked.connect(self._updateReferenceComponents)
        self.ui.structure_rb.clicked.connect(self._updateReferenceComponents)
        self.ui.browse_btn.clicked.connect(self._onBrowseBtnClicked)
        self.ui.load_from_project_entry_pb.clicked.connect(
            self._onLoadFromPTBtnClicked)
        self.ui.load_selection_pb.clicked.connect(
            self._onLoadSelectionBtnClicked)
        self.trajectoryChanged.connect(self._onTrajectoryChange)
        self.atom_selector = atomselector.AtomSelector(self,
                                                       show_pick=False,
                                                       show_selection=False,
                                                       show_plus=True)
        self.atom_selector.setAsl('all')
        self._updateSuperimposeComponents()
        self.ui.superimpose_group.buttonToggled.connect(
            self._updateSuperimposeComponents)
        self.ui.all_selector_layout.addWidget(self.atom_selector)

    def _updateReferenceComponents(self):
        """
        Enables / disables components based on what reference radio button is
        selected.
        """
        frame_enabled = self.ui.frame_rb.isChecked()
        structure_enabled = not frame_enabled
        self.ui.frame_sb.setEnabled(frame_enabled)
        self.ui.load_from_file_rb.setEnabled(structure_enabled)
        self.ui.browse_btn.setEnabled(structure_enabled)
        self.ui.project_entry_rb.setEnabled(structure_enabled)
        self.ui.load_from_project_entry_pb.setEnabled(structure_enabled)
        self.ui.loaded_structure_lbl.setEnabled(structure_enabled)

    def _updateSuperimposeComponents(self):
        """
        Enable the custom superimpose atom selector only when the "other
        atoms" option is chosen.
        """
        enable_atom_selector = self.ui.other_atoms_rb.isChecked()
        self.atom_selector.setEnabled(enable_atom_selector)
        self.ui.load_selection_pb.setEnabled(enable_atom_selector)

    def getStructure(self):
        """
        Returns structure to calculate RMSD against based on settings

        :return: The loaded reference structure, or None if none is loaded
        :rtype: structure.Structure or None
        """
        return self._structure

    def getSuperimposeAsl(self):
        """
        Get the superimpose asl string in the dialog

        :return: The atom selector's ASL, the Workspace selection ASL, or
            None if neither superimpose option is checked
        :rtype: str or None
        """
        if self.ui.other_atoms_rb.isChecked():
            return self.atom_selector.getAsl()
        elif self.ui.selected_atoms_rb.isChecked():
            return self.getSelectedAsl()

    def getFrameNumber(self):
        """
        :return: 1-based reference frame number from the spin box
        :rtype: int
        """
        return self.ui.frame_sb.value()

    def _onStructureLoadFailed(self):
        """Forget any loaded reference structure and warn the user."""
        self._structure = None
        self.ui.loaded_structure_lbl.clear()
        self.warning('Could not load input structure.')

    def _onBrowseBtnClicked(self):
        """Let the user pick a structure file and load it as reference."""
        filter_str = ';;'.join([
            'Maestro File (*.mae *.mae.gz *.maegz)',
            'MDL SD (*.sd *.sdf *.mol *.sdfgz *.sdf.gz)',
            'MOL2 (*.mol2)',
            'PDB (*.pdb *.pdb.gz *.ent *.ent.gz)',
        ])
        fpath = filedialog.get_open_file_name(caption='Load Structure',
                                              filter=filter_str)
        if not fpath:
            return
        try:
            st = structure.Structure.read(fpath)
        except Exception:
            # Any read/parse failure is reported to the user rather than
            # propagated; KeyboardInterrupt/SystemExit are no longer caught.
            self._onStructureLoadFailed()
            return
        self._structure = st
        self.ui.loaded_structure_lbl.setText(os.path.basename(fpath))

    def _onLoadFromPTBtnClicked(self):
        """Load the single selected Project Table entry as reference."""
        try:
            pt = maestro.project_table_get()
            rows = list(pt.selected_rows)
            # Check the count before fetching any structures
            if len(rows) != 1:
                if rows:
                    self.warning('More than one row selected.')
                else:
                    self.warning('No row selected.')
                return
            st = rows[0].getStructure()
        except Exception:
            self._onStructureLoadFailed()
            return
        self._structure = st
        self.ui.loaded_structure_lbl.setText(st.title)

    def _onTrajectoryChange(self, max_frames=1):
        """
        Reset the dialog and limit the frame spin box to the new trajectory.

        :param max_frames: Number of frames in the new trajectory
        :type max_frames: int
        """
        self.reset()
        self.ui.frame_sb.setRange(1, max_frames)

    def getSelectedAsl(self):
        """
        :return: ASL for the atoms selected in the Workspace, or '' if none
        :rtype: str
        """
        aids = maestro.selected_atoms_get()
        if aids:
            return "atom.n " + ",".join(map(str, aids))
        return ''

    def _onLoadSelectionBtnClicked(self):
        """Copy the Workspace selection into the atom selector."""
        asl_text = self.getSelectedAsl()
        self.atom_selector.setAsl(asl_text)

    def reset(self):
        """
        Resets the RMSD settings to taking data from frame 1
        """
        self._structure = None
        self.ui.loaded_structure_lbl.clear()
        self.ui.frame_rb.setChecked(True)
        self.ui.frame_sb.setValue(1)
        self.atom_selector.setAsl('all')
        self._updateReferenceComponents()

    def _calculateFitAttributes(self, entry_traj):
        """
        Calculates fit aids and fit reference positions using the RMSD
        dialog values.

        :param entry_traj: Entry trajectory; its ``cms_model`` and
            ``trajectory`` attributes are used.
        :return: Fit aids and fit reference positions; both are None when
            superimposing on the current Workspace selection.
        :rtype: tuple(list(int) or None,
            list((float, float, float)) or None)
        :raise RuntimeError: If the superimpose ASL is invalid or matches no
            atoms, if no reference structure is loaded, or if the reference
            structure's atom count does not match.
        """
        cms_model = entry_traj.cms_model
        selected_asl = self.getSelectedAsl()
        fit_asl = self.getSuperimposeAsl() or selected_asl
        if not analyze.validate_asl(fit_asl):
            raise RuntimeError(f'Invalid RMSD asl selection: {fit_asl}')
        fit_aids = None
        fit_ref_pos = None
        # Calculate fit arguments if we're not superimposing over the current
        # selection
        if fit_asl != selected_asl:
            fit_aids = cms_model.select_atom(fit_asl)
            if len(fit_aids) == 0:
                raise RuntimeError(
                    f'No valid atoms found for superimpose asl "{fit_asl}" on '
                    'the given structure. Plot could not be created.')
            if self.ui.frame_rb.isChecked():
                # Backend frame is 0-based, so offset by 1
                fr_idx = self.getFrameNumber() - 1
                frame = entry_traj.trajectory[fr_idx]
                # NOTE(review): this updates cms_model.fsys_ct in place from
                # the chosen frame and then reads cms_model.atom coordinates;
                # confirm both share coordinates.
                topo.update_fsys_ct_from_frame_GF(cms_model.fsys_ct,
                                                  cms_model, frame)
                fit_ref_pos = [cms_model.atom[aid].xyz for aid in fit_aids]
            else:
                dlg_st = self.getStructure()
                if dlg_st is None:
                    raise RuntimeError(
                        'No reference structure has been loaded.')
                dlg_st_aids = analyze.evaluate_asl(dlg_st, fit_asl)
                fit_ref_pos = [dlg_st.atom[aid].xyz for aid in dlg_st_aids]
                if len(fit_ref_pos) != len(fit_aids):
                    raise RuntimeError(
                        'Unequal number of atoms in superimposed reference '
                        f'structure ({len(fit_ref_pos)}) and superimposed '
                        f'workspace selection ({len(fit_aids)})')
        return fit_aids, fit_ref_pos