Source code for schrodinger.stepper.sideinputs

"""
Utility steps for creating side inputs in stepper workflows.

For example, in a workflow with steps A, B, C, and D, a ForkStep and JoinStep
can be set up so that all outputs from A are passed along to D. This allows
outputs from A to get to D even if B or C would normally filter those inputs.

Example::

    class MyWorkflow(stepper.Chain):
        def buildChain(self):
            a = A()
            self.addStep(a)
            fork = ForkStep(step=a)
            self.addStep(fork)
            self.addStep(B())
            self.addStep(C())
            self.addStep(JoinStep(fork=fork))
            self.addStep(D())

"""
import os
import uuid

from schrodinger.models import parameters
from schrodinger import stepper

# Path ``$SCHRODINGER/run``, built from the SCHRODINGER environment variable.
# The variable is read with ``os.environ[...]``, so importing this module
# raises KeyError when SCHRODINGER is not set.
# NOTE(review): this constant is not referenced by the code in this module;
# presumably it is kept for external callers -- confirm before removing.
SCHRODINGER_RUN = os.path.join(os.environ['SCHRODINGER'], 'run')


class ForkStep(stepper.UnbatchedReduceStep):
    """
    A step that records every input passing through it so a later JoinStep
    can emit those inputs again.

    Inputs are passed through unchanged; as each one goes by it is also
    serialized to a hidden, uniquely named file in the current working
    directory. See the module docstring for more info and an example.
    """

    def __init__(self, step):
        """
        :param step: the step whose outputs this fork receives; its
            ``Output`` type and ``OutputSerializer`` are reused for both the
            input and output side of this step.
        """
        data_type = step.Output
        data_serializer = step.OutputSerializer
        self.Input = self.Output = data_type
        self.InputSerializer = self.OutputSerializer = data_serializer
        # Leading dot makes the file hidden; the UUID keeps multiple forks
        # (or runs) from clobbering each other's files.
        self._pipe_fname = f'.{uuid.uuid4()}.forkfile'
        super().__init__()

    def reduceFunction(self, inps):
        """
        Write each input to the fork file (one serialized input per line)
        and yield it unchanged.

        The file is opened in ``'w'`` mode, so any previous contents are
        replaced. It is closed once ``inps`` is exhausted.
        """
        to_string = self.getOutputSerializer().toString
        with open(self._pipe_fname, 'w') as fork_file:
            for item in inps:
                fork_file.write(f'{to_string(item)}\n')
                yield item

    def getPipeFilename(self):
        """
        :return: the (relative) path of the file this fork writes to.
        """
        return self._pipe_fname

    def report(self, prefix=''):
        """
        Log this step's id, preceded by ``prefix``.
        """
        step_id = self.getStepId()
        stepper.logger.info(f'{prefix} - {step_id}')


class JoinStep(stepper.UnbatchedReduceStep):
    """
    A step that re-emits the inputs recorded by a preceding ForkStep.

    Inputs arriving from upstream are passed along first, followed by every
    input the fork wrote to its pipe file. See the module docstring for more
    info and an example.
    """

    def __init__(self, fork):
        """
        :param fork: the ForkStep whose recorded inputs should be re-emitted.
            Its serializer and data types are reused by this step.
        """
        self._fork = fork
        self.InputSerializer = self.OutputSerializer = fork.OutputSerializer
        self._in_fname = fork.getPipeFilename()
        super().__init__()

    @property
    def Input(self):
        """The input type, taken from the fork."""
        return self._fork.Input

    @property
    def Output(self):
        """The output type, taken from the fork."""
        return self._fork.Output

    def reduceFunction(self, inps):
        """
        Yield all upstream inputs, then all inputs saved in the fork file.
        """
        yield from inps
        # The fork writes its file while its own inputs are being consumed,
        # so the file is only read once the upstream inputs are exhausted
        # (presumably the fork has finished writing by then).
        recorded = self.getOutputSerializer().deserialize(self._in_fname)
        yield from recorded

    def report(self, prefix=''):
        """
        Log this step's id and the id of the fork it joins from.
        """
        own_id = self.getStepId()
        fork_id = self._fork.getStepId()
        stepper.logger.info(f'{prefix} - {own_id} <- {fork_id}')


class JoinFromFileStep(stepper.UnbatchedReduceStep):
    """
    A step for injecting inputs read from a file into a chain.

    To use, add it to your chain and set the step's ``join_file`` setting to
    the path of your data file. Exactly one of ``Input`` or
    ``InputSerializer`` must be given when the step is constructed; the
    step's output type and serializer are the same as its input ones.
    """

    class Settings(parameters.CompoundParam):
        # Path of the file whose contents are injected into the chain.
        join_file: stepper.StepperFile = None

    def __init__(self, Input=None, InputSerializer=None, **kwargs):
        """
        :param Input: the data type of the step's inputs and outputs. Must
            not be given together with ``InputSerializer``.
        :param InputSerializer: serializer for the step's inputs and outputs;
            the data type is taken from its ``DataType`` attribute. Must not
            be given together with ``Input``.
        :param kwargs: forwarded to the base step's ``__init__``.
        :raises TypeError: if neither, or both, of ``Input`` and
            ``InputSerializer`` are given.
        """
        if Input is None and InputSerializer is None:
            raise TypeError("Must set either Input or InputSerializer at "
                            "step initialization time.")
        elif Input is not None and InputSerializer is not None:
            raise TypeError("Can't set both Input _and_ InputSerializer")
        # Use an explicit None check (matching the validation above) rather
        # than truthiness, so a serializer that happens to be falsy is still
        # honored.
        if InputSerializer is not None:
            Input = InputSerializer.DataType
        # When only ``Input`` was given, the serializers are set to None.
        # NOTE(review): presumably the base step then picks a default
        # serializer for ``Input`` -- confirm.
        self.InputSerializer = self.OutputSerializer = InputSerializer
        self.Input = self.Output = Input
        super().__init__(**kwargs)

    def reduceFunction(self, inps):
        """
        Yield every input deserialized from the ``join_file`` setting, then
        every input received from upstream.
        """
        serializer = self._getInputSerializer()
        yield from serializer.deserialize(self.settings.join_file)
        yield from inps
#==============================================================================
# PUBSUB FUNCTIONALITY
#
# The code below extends the Join* steps so that, when pubsub is in use, the
# extra inputs are simply added to the input topic. This optimizes the Join*
# steps so they don't needlessly read a whole topic just to append a few extra
# inputs.
#
# NOTE: This section is designed to be 'transparent': if a chain is not using
# pubsub, all Join* steps behave normally. Additionally, if this section of
# code is removed, the base functionality of the Join* steps will still work.
#==============================================================================
class JoinStep(stepper.PubsubEnabledStepMixin, JoinStep):
    """
    Pubsub-aware JoinStep.

    When the step has an input topic, the fork's saved inputs are uploaded
    to that topic, and the same topic is used as this step's output topic.
    Without an input topic, the base JoinStep behavior is used unchanged.
    """

    def outputs(self):
        """
        Return this step's outputs.

        In pubsub mode, upload the fork file's contents to the input topic,
        reuse that topic as the output topic, and read the outputs back from
        it. Otherwise, defer to the non-pubsub implementation.
        """
        if not self.usingPubsub():
            return super().outputs()
        topic = self.getInputTopic()
        self.setOutputTopic(topic)
        serializer = self.getOutputSerializer()
        saved_inputs = serializer.deserialize(self._in_fname)
        self._uploadToTopic(saved_inputs, serializer, topic)
        return self._deserializeFromOutputTopic()

    def usingPubsub(self):
        """
        :return: whether this step has a (truthy) input topic.
        """
        topic = self.getInputTopic()
        return bool(topic)


class JoinFromFileStep(stepper.PubsubEnabledStepMixin, JoinFromFileStep):
    """
    Pubsub-aware JoinFromFileStep.

    When the step has an input topic, the contents of the ``join_file``
    setting are uploaded to that topic, and the same topic is used as this
    step's output topic. Without an input topic, the base JoinFromFileStep
    behavior is used unchanged.
    """

    def outputs(self):
        """
        Return this step's outputs.

        In pubsub mode, upload the join file's contents to the input topic,
        reuse that topic as the output topic, and read the outputs back from
        it. Otherwise, defer to the non-pubsub implementation.
        """
        if not self.usingPubsub():
            return super().outputs()
        topic = self.getInputTopic()
        self.setOutputTopic(topic)
        serializer = self.getOutputSerializer()
        file_inputs = serializer.deserialize(self.settings.join_file)
        self._uploadToTopic(file_inputs, serializer, topic)
        return self._deserializeFromOutputTopic()

    def usingPubsub(self):
        """
        :return: whether this step has a (truthy) input topic.
        """
        topic = self.getInputTopic()
        return bool(topic)