Source code for schrodinger.maestro_utils.maestro_sync

from contextlib import contextmanager

from decorator import decorator

from schrodinger import get_maestro
from schrodinger.models import diffy
from schrodinger.models import parameters
from schrodinger.models.mappers import TargetMixin
from schrodinger.project import utils as proj_utils
from schrodinger.project.manager import EntryGroupManager
from schrodinger.Qt.QtCore import QObject
from schrodinger.Qt.QtCore import pyqtSignal
from schrodinger.ui.qt.appframework2.maestro_callback import CALLBACK_FUNCTIONS
from schrodinger.ui.qt.decorators import suppress_signals
from schrodinger.utils.scollections import IdSet

# Module-level handle to Maestro, obtained once at import time.
maestro = get_maestro()

# Name of the structure property used for an entry ID (not referenced by the
# code in this module's visible portion).
PROPNAME_EID = 's_m_entry_id'

# Error message used whenever a structure getter is required but unset.
NO_GETTER_MSG = ('You must define a structure getter function '
                 'using setStructureGetter().')


class BaseMaestroSync:
    """
    Basic Maestro synchronization class that can be used as infrastructure for
    responding to specific events with the Maestro project and Workspace.

    :ivar _cb_type_callback_map: a dictionary mapping callback types to the
        set of callback functions associated with that type
    :vartype _cb_type_callback_map: dict[str, set[Callable]]

    :ivar _callbacks_active: whether the stored callbacks should be registered
        at this time
    :vartype _callbacks_active: bool
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._cb_type_callback_map = {
            cb_type: set() for cb_type in CALLBACK_FUNCTIONS.keys()
        }
        self._callbacks_active = False
        self.addProjectUpdateCallback(self.onProjectUpdated)
        self.addProjectCloseCallback(self.onProjectClosed)
        self.addWorkspaceChangeCallback(self.onWorkspaceChanged)
        # NOTE: Not defining or adding `self.onHoverChanged` due to MAE-45464

    def addProjectUpdateCallback(self, callback_fn):
        """
        Add function to list of "project updated" callbacks.

        :param callback_fn: the function to add
        :type callback_fn: Callable
        """
        self._addCallback(maestro.PROJECT_UPDATE_CALLBACK, callback_fn)

    def addProjectCloseCallback(self, callback_fn):
        """
        Add function to list of "project closed" callbacks.

        :param callback_fn: the function to add
        :type callback_fn: Callable
        """
        self._addCallback(maestro.PROJECT_CLOSE_CALLBACK, callback_fn)

    def addWorkspaceChangeCallback(self, callback_fn):
        """
        Add function to list of "workspace changed" callbacks.

        :param callback_fn: the function to add
        :type callback_fn: Callable
        """
        self._addCallback(maestro.WORKSPACE_CHANGED_CALLBACK, callback_fn)

    def addHoverCallback(self, callback_fn):
        """
        Add function to list of "hover" callbacks.

        :param callback_fn: the function to add. This function should expect
            to receive a single parameter of type `int`. The parameter
            represents the number of the atom which the mouse is currently
            resting over (or an invalid index if 0---see mm.mmct_valid_atom)
        :type callback_fn: Callable
        """
        self._addCallback(maestro.HOVER_CALLBACK, callback_fn)

    def _addCallback(self, cb_type, callback_fn):
        """
        Add a callback to the specified set of callbacks. If callbacks are
        active, immediately register it.

        :param cb_type: the type of Maestro callback to associate with the
            supplied callback function
        :type cb_type: str
        :param callback_fn: the function to associate with a Maestro callback
        :type callback_fn: Callable
        """
        self._cb_type_callback_map[cb_type].add(callback_fn)
        if self._callbacks_active:
            self._registerCallbacks()

    def setCallbacksActive(self, active):
        """
        Register or deregister all callbacks.

        When activating, each callback that gets registered with Maestro is
        invoked once immediately if it is a "project updated" callback (with
        no arguments) or a "workspace changed" callback (with
        `maestro.WORKSPACE_CHANGED_EVERYTHING`), so that this object starts
        out synchronized with Maestro.

        :param active: whether to register or deregister callbacks with Maestro
        :type active: bool
        """
        self._callbacks_active = active
        if active:
            self._registerCallbacks()
        else:
            self._deregisterCallbacks()

    @contextmanager
    def suspendCallbacks(self):
        """
        Context manager to temporarily disable Maestro callbacks.

        If callbacks are active on entry they are deregistered for the duration
        of the block and re-registered on exit (which re-runs the synchronizing
        callbacks; see `setCallbacksActive`). If they are already inactive,
        nothing is deregistered. In all cases the activation state on entry is
        restored on exit.
        """
        init_active = self._callbacks_active
        if init_active:
            self.setCallbacksActive(False)
        try:
            yield
        finally:
            # Skip the redundant "deactivate while already inactive" call, but
            # still restore the entry state if the block changed it.
            if init_active or self._callbacks_active:
                self.setCallbacksActive(init_active)

    def _registerCallbacks(self):
        """
        Register all stored Maestro callbacks.
        """
        for cb_type, cb_fns in self._cb_type_callback_map.items():
            cb_info = CALLBACK_FUNCTIONS[cb_type]
            for cb_fn in cb_fns:
                if (not cb_info.maestro_check_callback or
                        not maestro.is_function_registered(cb_type, cb_fn)):
                    cb_info.add(cb_fn)
                    # For certain callback types, execute the callback as
                    # soon as it is registered to ensure synchronization
                    if cb_type == maestro.PROJECT_UPDATE_CALLBACK:
                        cb_fn()
                    elif cb_type == maestro.WORKSPACE_CHANGED_CALLBACK:
                        cb_fn(maestro.WORKSPACE_CHANGED_EVERYTHING)

    def _deregisterCallbacks(self):
        """
        Deregister all stored Maestro callbacks.
        """
        for cb_type, cb_fns in self._cb_type_callback_map.items():
            cb_info = CALLBACK_FUNCTIONS[cb_type]
            for cb_fn in cb_fns:
                if (not cb_info.maestro_check_callback or
                        maestro.is_function_registered(cb_type, cb_fn)):
                    cb_info.remove(cb_fn)

    def onProjectUpdated(self):
        """
        Callback method for project update events. Should be overridden in
        concrete subclasses to add functionality.
        """
        pass

    def onProjectClosed(self):
        """
        Callback method for project close events. Should be overridden in
        concrete subclasses to add functionality.
        """
        pass

    def onWorkspaceChanged(self, what_changed):
        """
        Callback method for workspace change events. Should be overridden in
        concrete subclasses to add functionality.

        :param what_changed: the kind of change that occurred in the Workspace;
            will be one of the `WORKSPACE_CHANGED_` constants in maestro.py
        :type what_changed: str
        """
        pass
@decorator
def requires_structure_getter(func, self, *args, **kwargs):
    """
    Guard a `ProjectEntryMaestroSync` method so that it refuses to run until a
    structure getter has been assigned.

    :raises RuntimeError: if `self._structure_getter` is None when the
        decorated method is invoked
    """
    getter_missing = self._structure_getter is None
    if getter_missing:
        raise RuntimeError(NO_GETTER_MSG)
    return func(self, *args, **kwargs)
class ProjectEntryMaestroSync(BaseMaestroSync, QObject):
    """
    Maestro sync class that acts as an interface with project entries.

    If the user has a `ParamListParam` (PLP) or multiple PLPs that they wish to
    have correspond with entries in the project, they can assign them as the
    models for one or more of the targets associated with this class:

    1. `entry_plp_target`: creates and tracks entries associated with items in
       the model PLP. If the user wishes to use any of the subsequent targets,
       they must first assign a model for this target.
    2. `group_plp_target`: moves associated entries into an entry group
    3. `select_plp_target`: selects associated entries
    4. `include_plp_target`: includes associated entries

    In addition to assigning model PLPs to the above targets, the user must
    also assign a function that can return a structure from a PLP item using
    the `setStructureGetter()` method. A simple example would just be::

        mae_sync = ProjectEntryMaestroSync()
        mae_sync.setStructureGetter(lambda item: item.structure)

    Finally, the user may also assign a `StringParam` model to `title_target`,
    which tracks the title for the entry group used by `group_plp_target`.

    :ivar _structure_getter: the function used to access structure objects from
        PLP items
    :vartype _structure_getter: callable or NoneType

    :ivar _st_eid_map: a dictionary mapping structures (as returned by the
        structure getter) to their corresponding entry IDs
    :vartype _st_eid_map: dict[structure.Structure, str]
    """

    def __init__(self, parent=None, group_name=None, parent_group_name=None):
        """
        :param parent: the parent for this object
        :type parent: QtCore.QObject

        :param group_name: a custom name for the entry group tracked by
            `group_plp_target`
        :type group_name: str or NoneType

        :param parent_group_name: the parent group name for the entry group
            tracked by `group_plp_target`, if it is meant to be a subgroup
        :type parent_group_name: str
        """
        super().__init__(parent=parent)
        self._structure_getter = None
        self._st_eid_map = {}
        self._group_manager = EntryGroupManager(
            group_name=group_name, parent_group_name=parent_group_name)

        self.title_target = TitleTarget(self._group_manager)
        self.entry_plp_target = PLPTarget()
        self.entry_plp_target.PLPMutated.connect(self._onEntryPLPMutated)
        self.group_plp_target = PLPTarget()
        self.group_plp_target.PLPMutated.connect(self._onGroupPLPMutated)
        self.select_plp_target = PLPTarget()
        self.select_plp_target.PLPMutated.connect(self._onSelectPLPMutated)
        self.include_plp_target = PLPTarget()
        self.include_plp_target.PLPMutated.connect(self._onIncludePLPMutated)

    def setStructureGetter(self, getter):
        """
        Assign the function that can be used by the PLP targets on this class
        to retrieve structure objects from items on the PLP, e.g. ::

            plp = self.entry_plp_target.getPLP()
            for item in plp:
                st = self._structure_getter(item)

        This getter must be set in order for this class to function.

        :param getter: a function that will return a structure object from a
            PLP item
        :type getter: callable
        """
        self._structure_getter = getter

    def groupName(self):
        """
        :return: the entry group name associated with this class
        :rtype: str
        """
        return self._group_manager.name

    @requires_structure_getter
    def _onEntryPLPMutated(self, new_plp, old_plp):
        """
        Respond to the standard entry PLP changing by adding or removing
        entries from the project.

        :param new_plp: the new value of the model PLP
        :type new_plp: parameters.ParamListParam
        :param old_plp: the old value of the model PLP
        :type old_plp: parameters.ParamListParam
        """
        pt = proj_utils.get_PT()
        plp_diff = diffy.get_diff_list(new_plp, old_plp)

        # Forget removed items and collect the entries they owned
        eids_to_remove = set()
        for item, _ in plp_diff.removed:
            st = self._structure_getter(item)
            entry_id = self._st_eid_map.pop(st, None)
            if entry_id is not None:
                eids_to_remove.add(entry_id)

        # Re-order the "added" items so that they match the order in `new_plp`
        added_ids = IdSet(item for item, _ in plp_diff.added)
        added_items = [item for item in new_plp if item in added_ids]

        # Remove and import inside ONE suspension. Leaving a suspension
        # re-registers callbacks, which runs `onProjectUpdated`; that drops
        # PLP items without a valid entry, so a separate suspension around the
        # removal would discard newly added items before they are imported.
        with self.suspendCallbacks():
            if eids_to_remove:
                proj_utils.remove_entries(eids_to_remove)
            for item in added_items:
                st = self._structure_getter(item)
                # Reuse the entry of a known structure if it still exists;
                # otherwise import the structure as a new entry
                entry_id = self._st_eid_map.get(st)
                row = pt.getRow(entry_id) if entry_id is not None else None
                if row is None:
                    row = pt.importStructure(st)
                self._st_eid_map[st] = row.entry_id

        if plp_diff.added:
            self._addEntriesToGroup()
        self._updateEntrySelection()
        self._updateEntryInclusion()

    @requires_structure_getter
    def _onGroupPLPMutated(self, new_plp, old_plp):
        """
        Respond to the group PLP changing by adding or removing entries from
        the entry group.

        :param new_plp: the new value of the model PLP
        :type new_plp: parameters.ParamListParam
        :param old_plp: the old value of the model PLP
        :type old_plp: parameters.ParamListParam
        """
        removed = diffy.get_removed_list(new_plp, old_plp)
        rows_to_move_out = []
        for item, _ in removed:
            st = self._structure_getter(item)
            entry_id = self._st_eid_map.get(st)
            if entry_id is None:
                continue
            row = proj_utils.get_row(entry_id)
            if row:
                rows_to_move_out.append(row)
        with self.suspendCallbacks():
            for row in rows_to_move_out:
                row.ungroup()
        self._addEntriesToGroup()

    def _onSelectPLPMutated(self, new_plp, old_plp):
        """
        Respond to the selection PLP changing by selecting only entries
        associated with items from the selection PLP.

        :param new_plp: the new value of the model PLP
        :type new_plp: parameters.ParamListParam
        :param old_plp: the old value of the model PLP
        :type old_plp: parameters.ParamListParam
        """
        self._updateEntrySelection()

    def _onIncludePLPMutated(self, new_plp, old_plp):
        """
        Respond to the inclusion PLP changing by including only entries
        associated with items from the inclusion PLP.

        :param new_plp: the new value of the model PLP
        :type new_plp: parameters.ParamListParam
        :param old_plp: the old value of the model PLP
        :type old_plp: parameters.ParamListParam
        """
        self._updateEntryInclusion()

    @requires_structure_getter
    def _getEntryIDsFromPLP(self, plp):
        """
        Return a list of entry IDs associated with each item of the PLP for
        which one can be found. If no entry ID can be found for one of the
        items in the PLP, skip it.

        :raises RuntimeError: if no structure getter is defined

        :param plp: the PLP whose items should be looked up
        :type plp: parameters.ParamListParam
        :return: the known entry IDs, in PLP order
        :rtype: list[str]
        """
        entry_ids = []
        for item in plp:
            entry_id = self._st_eid_map.get(self._structure_getter(item))
            if entry_id is not None:
                entry_ids.append(entry_id)
        return entry_ids

    def _addEntriesToGroup(self):
        """
        Search the entry group PLP for new items. If any can be found, add
        their associated entries to the entry group.
        """
        plp = self.group_plp_target.getPLP()
        if plp is None:
            return
        group_name = self._group_manager.name
        rows_to_move_in = []
        for entry_id in self._getEntryIDsFromPLP(plp):
            row = proj_utils.get_row(entry_id)
            # Parenthesized so a missing row is skipped instead of having
            # `row.group` dereferenced on it
            if row and (not row.group or row.group.name != group_name):
                rows_to_move_in.append(row)
        if rows_to_move_in:
            with self.suspendCallbacks():
                for row in rows_to_move_in:
                    row.moveToGroup(group_name)

    def _updateEntrySelection(self):
        """
        Select only entries associated with items from the selection PLP.
        """
        plp = self.select_plp_target.getPLP()
        if plp is None:
            return
        entry_ids = self._getEntryIDsFromPLP(plp)
        pt = proj_utils.get_PT()
        with self.suspendCallbacks():
            pt.selectRows(entry_ids=entry_ids)

    def _updateEntryInclusion(self):
        """
        Include only entries associated with items from the inclusion PLP.
        """
        plp = self.include_plp_target.getPLP()
        if plp is None:
            return
        entry_ids = self._getEntryIDsFromPLP(plp)
        pt = proj_utils.get_PT()
        with self.suspendCallbacks():
            pt.includeRows(entry_ids=entry_ids)

    def onProjectUpdated(self):
        """
        Respond to the project updating by searching for entries associated
        with the project entry PLP. If any have been deleted from the project,
        remove the associated model PLP items.

        :raises RuntimeError: if an entry PLP is set but no structure getter
            is defined
        """
        plp = self.entry_plp_target.getPLP()
        if plp is None:
            return
        if self._structure_getter is None:
            raise RuntimeError(NO_GETTER_MSG)

        pt = proj_utils.get_PT()
        all_entry_ids = {row.entry_id for row in pt.all_rows}
        removed_sts = set()
        items_to_remove = []
        for item in plp:
            st = self._structure_getter(item)
            # Items with no mapped entry (None) are also treated as removed
            if self._st_eid_map.get(st) not in all_entry_ids:
                items_to_remove.append(item)
                removed_sts.add(st)
        if not items_to_remove:
            return

        # Remove silently, then emit a single mutation signal at the end
        with suppress_signals(plp):
            for item in items_to_remove:
                plp.remove(item)
        for st in removed_sts:
            self._st_eid_map.pop(st, None)
        plp.emitMutated()
class PLPTarget(TargetMixin, QObject):
    """
    A generic target for a PLP model.

    :ivar PLPMutated: a signal that propagates the `ParamListParam.mutated`
        signal from the model PLP
    :vartype PLPMutated: QtCore.pyqtSignal
    """
    PLPMutated = pyqtSignal(object, object)

    def __init__(self, parent=None):
        super().__init__(parent=parent)
        self._plp = None

    def targetGetValue(self):
        # See `mappers.TargetMixin` for full documentation.
        return self.getPLP()

    def targetSetValue(self, value):
        # See `mappers.TargetMixin` for full documentation.
        self.setPLP(value)

    def getPLP(self):
        """
        :return: the PLP model for this target, if one has been defined
        :rtype: parameters.ParamListParam or NoneType
        """
        return self._plp

    def setPLP(self, plp):
        """
        Assign a new PLP model, or clear the current one by passing None.
        Signals from the previous PLP (if any) are disconnected and signals
        from the new PLP (if any) are connected.

        :raises TypeError: if an invalid parameter is supplied

        :param plp: a PLP containing structures, or None
        :type plp: parameters.ParamListParam or NoneType
        """
        if isinstance(plp, parameters.DictParam):
            msg = (f'{type(self).__name__}.setPLP() must be called with'
                   ' a concrete PLP, not an abstract PLP.')
            raise TypeError(msg)
        # None is documented as valid (it clears the target), so only
        # type-check real values
        if plp is not None and type(plp).__name__ != 'ListWithSignal':
            msg = (f'{type(self).__name__}.setPLP() expects a concrete PLP.'
                   f' Instead, got an argument of type {type(plp)}.')
            raise TypeError(msg)

        if self._plp is not None:
            for signal, slot in self._getSignalsAndSlots(self._plp):
                signal.disconnect(slot)
        self._plp = plp
        if self._plp is not None:
            for signal, slot in self._getSignalsAndSlots(self._plp):
                signal.connect(slot)

    def _getSignalsAndSlots(self, plp):
        """
        :param plp: the PLP whose signals should be forwarded
        :type plp: parameters.ParamListParam
        :return: (signal, slot) pairs to connect/disconnect for `plp`
        :rtype: list[tuple]
        """
        return [
            (plp.mutated, self.PLPMutated)
        ]  # yapf: disable
class TitleTarget(TargetMixin, QObject):
    """
    Mapper target that reads and writes the title of an entry group through
    the group manager it was constructed with.
    """

    def __init__(self, group_manager, parent=None):
        super().__init__(parent=parent)
        # Object whose `title` attribute this target reads and writes
        self._group_manager = group_manager

    def targetGetValue(self):
        """
        See `mappers.TargetMixin` for full documentation.

        :return: the group title currently held by the group manager
        :rtype: str
        """
        manager = self._group_manager
        return manager.title

    def targetSetValue(self, value):
        """
        See `mappers.TargetMixin` for full documentation.

        :param value: the title to store on the group manager
        :type value: str
        """
        manager = self._group_manager
        manager.title = value