Source code for schrodinger.application.steps.basesteps

from schrodinger import stepper
from schrodinger.utils import log

from .dataclasses import FileInMixin
from .dataclasses import MaeMaeMixin
from .dataclasses import MaeOutMixin
from .dataclasses import MolInMixin
from .dataclasses import MolMolMixin
from .dataclasses import MolOutMixin
from .utils import to_string


class LoggerMixin:
    """
    Mixin for stepper steps that writes a debug record for every input the
    step consumes and for every output it produces.

    The host class is expected to supply `getStepId()` and `mapFunction()`;
    neither is defined here.
    """

    def __init__(self, *args, **kwargs):
        """
        Forward all arguments to the next class in the MRO, then create the
        output logger named after this step's id.
        """
        super().__init__(*args, **kwargs)
        step_id = self.getStepId()
        self.logger = log.get_output_logger(step_id)

    def reduceFunction(self, inputs):
        """
        Apply `self.mapFunction` to each input in turn, yielding every
        result and logging both directions at debug level.

        :param inputs: iterable of inputs for this step
        :return: generator over the outputs of `mapFunction` for each input
        """
        for inp in inputs:
            received = f'{self.getStepId()} <- {to_string(inp)}'
            self.logger.debug(received)
            for output in self.mapFunction(inp):
                produced = f'{self.getStepId()} -> {to_string(output)}'
                self.logger.debug(produced)
                yield output
class MolReduceStep(MolMolMixin, stepper.ReduceStep):
    """
    A `stepper.ReduceStep` combined with `MolMolMixin`.

    Adds no behavior of its own; everything comes from the two bases.
    """
class MolMapStep(LoggerMixin, MolReduceStep):
    """
    A `MolReduceStep` with the per-input/per-output debug logging of
    `LoggerMixin`.

    `LoggerMixin.reduceFunction` calls `self.mapFunction` on each input, so
    concrete subclasses are presumably expected to implement `mapFunction`.
    """
class MaeReduceStep(MaeMaeMixin, stepper.ReduceStep):
    """
    A `stepper.ReduceStep` combined with `MaeMaeMixin`.

    Adds no behavior of its own; everything comes from the two bases.
    """
class MaeMapStep(LoggerMixin, MaeReduceStep):
    """
    A `MaeReduceStep` with the per-input/per-output debug logging of
    `LoggerMixin`.

    `LoggerMixin.reduceFunction` calls `self.mapFunction` on each input, so
    concrete subclasses are presumably expected to implement `mapFunction`.
    """
class OptionalStepsMixin:
    """
    A mixin that provides optional steps for a `Workflow` class.

    Whether a step is optional is determined by `OPTIONAL_SETTINGS_MAP`, a
    dictionary with step classes as the keys and the names of bool settings
    attributes as the values. A step listed in the map is only added to the
    chain when the corresponding setting is truthy; steps not listed are
    always added.

    The mixin must precede the `Workflow` base in the class definition so
    that its `buildChain` takes precedence. E.g.::

        class ExampleOptionalStepChain(OptionalStepsMixin, MolMolWorkflow):
            STEPS = (filters.UniqueSmilesFilter, filters.SmartsFilter,
                     filters.PropertyFilter)
            OPTIONAL_SETTINGS_MAP = {filters.SmartsFilter: 'smarts_filter'}

            class Settings(parameters.CompoundParam):
                smarts_filter: bool = False
    """

    # Maps step class -> name of the settings sub-param that enables it.
    # Empty by default, i.e. every step in `STEPS` is mandatory.
    OPTIONAL_SETTINGS_MAP = {}

    def buildChain(self):
        """
        Instantiate and add each class in `self.STEPS`, in order, skipping
        any optional step whose controlling setting is falsy.
        """
        for step in self.STEPS:
            if step in self.OPTIONAL_SETTINGS_MAP:
                setting_name = self.OPTIONAL_SETTINGS_MAP[step]
                if not self.settings.getSubParam(setting_name):
                    # Optional step switched off in the settings.
                    continue
            self.addStep(step())
class Workflow(stepper.Chain):
    """
    A chain whose steps are declared through the `STEPS` class variable.

    Subclasses list the step classes they want, in order, in `STEPS`.
    """

    # Step classes to instantiate, in chain order.
    STEPS = ()

    def buildChain(self):
        """
        Create one instance of every class in `STEPS` and add it to the
        chain, preserving the declared order.
        """
        for step_cls in self.STEPS:
            new_step = step_cls()
            self.addStep(new_step)
class MolMolWorkflow(MolMolMixin, Workflow):
    """
    A `Workflow` combined with `MolMolMixin`.

    Adds no behavior of its own; everything comes from the two bases.
    """
class MolMaeWorkflow(MolInMixin, MaeOutMixin, Workflow):
    """
    A `Workflow` combined with `MolInMixin` and `MaeOutMixin`.

    Adds no behavior of its own; everything comes from the bases.
    """
class MaeMaeWorkflow(MaeMaeMixin, Workflow):
    """
    A `Workflow` combined with `MaeMaeMixin`.

    Adds no behavior of its own; everything comes from the two bases.
    """
class FileMolWorkflow(FileInMixin, MolOutMixin, Workflow):
    """
    A `Workflow` combined with `FileInMixin` and `MolOutMixin`.

    Adds no behavior of its own; everything comes from the bases.
    """
# ==============================================================================
# DEDUPLICATION BASE STEPS
# ==============================================================================
class CloudFilterChain(stepper.Chain):
    """
    Generic filter chain to use with cloud databases.

    This class is meant to be subclassed. A subclass has to provide:

    1. `self.Settings`
    2. `self.buildChain()`
    3. `self._setUpTable()`
    4. `self._validateTable()`
    """

    def setUp(self):
        """
        Delegate set-up to the subclass hook `_setUpTable`.
        """
        self._setUpTable()

    def _setUpTable(self):
        """
        Hook called from `setUp`; must be overridden by subclasses.

        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError

    def validateSettings(self):
        """
        Run the inherited settings validation and append the result of the
        subclass hook `_validateTable`.

        :return: the inherited validation result joined (with `+`) to the
            result of `_validateTable`; presumably both are lists of
            validation issues
        """
        inherited_result = super().validateSettings()
        table_result = self._validateTable()
        return inherited_result + table_result

    def _validateTable(self):
        """
        Hook called from `validateSettings`; must be overridden by
        subclasses.

        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError
# Sentinel representing default batch settings for upload and download step. _DEFAULT_BATCH_SETTINGS = object()
class UploadStep(MolInMixin, stepper.ReduceStep):
    """
    A `stepper.ReduceStep` with `MolInMixin` whose batch settings default to
    `stepper.BatchSettings(use_pubsub=True)`.
    """

    def __init__(self,
                 *args,
                 batch_settings=_DEFAULT_BATCH_SETTINGS,
                 **kwargs):
        """
        :param batch_settings: batch settings handed to the base class. When
            omitted, `stepper.BatchSettings(use_pubsub=True)` is used; any
            explicitly passed value is forwarded as is.

        All other positional and keyword arguments go to the base class
        unchanged.
        """
        caller_omitted = batch_settings is _DEFAULT_BATCH_SETTINGS
        effective_settings = (stepper.BatchSettings(use_pubsub=True)
                              if caller_omitted else batch_settings)
        super().__init__(*args, batch_settings=effective_settings, **kwargs)
class DownloadStep(MolOutMixin, stepper.MapStep):
    """
    A `stepper.MapStep` with `MolOutMixin` whose batch settings default to
    `stepper.BatchSettings(use_pubsub=True, size=5)`.
    """

    def __init__(self,
                 *args,
                 batch_settings=_DEFAULT_BATCH_SETTINGS,
                 **kwargs):
        """
        :param batch_settings: batch settings handed to the base class. When
            omitted, `stepper.BatchSettings(use_pubsub=True, size=5)` is
            used; any explicitly passed value is forwarded as is.

        All other positional and keyword arguments go to the base class
        unchanged.
        """
        caller_omitted = batch_settings is _DEFAULT_BATCH_SETTINGS
        effective_settings = (stepper.BatchSettings(use_pubsub=True, size=5)
                              if caller_omitted else batch_settings)
        super().__init__(*args, batch_settings=effective_settings, **kwargs)
class TableReduceStep(stepper.ReduceStep):
    """
    A reduce step that performs one action per distinct table among its
    inputs and then passes the (de-duplicated) inputs through.

    Subclasses implement the action in `_actOnTable`.
    """

    def reduceFunction(self, inps):
        """
        Collapse the inputs by their `getFullTableId()`, call `_actOnTable`
        once for every distinct id, then yield the retained inputs.

        When several inputs share an id, the last one seen is the one that
        is yielded. All actions run before the first output is yielded.

        :param inps: iterable of objects exposing `getFullTableId()`
        :return: generator over one input per distinct table id
        """
        tables_by_id = {}
        for table in inps:
            tables_by_id[table.getFullTableId()] = table
        for table_id in tables_by_id:
            self._actOnTable(table_id)
        yield from tables_by_id.values()

    def _actOnTable(self, table_id):
        """
        Hook called once per distinct table id; must be overridden by
        subclasses.

        :param table_id: value returned by an input's `getFullTableId()`
        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError