"""
A task represents a block of work that has a defined input and output and runs
without user intervention. Different task classes share a common external API
but have different implementations for defining and executing the work, such as
blocking calls, threads, subprocesses, or job control (see jobtasks).

To define a task, follow these basic instructions:

1. Choose a task class to subclass. The choice of task class is primarily
dictated by how the task needs to run - thread, subprocess, job, etc. See the
Task Class Selection Guide for help.

2. Override the input and output params. The task.input and task.output params
may be of any Param type, including CompoundParam (typical). For CompoundParams,
either use an existing class to override task.input, OR define a nested class
named Input within the task. Doing so will automatically override task.input.
The same goes for task.output. Example::

    class FooTask(tasks.ThreadFunctionTask):
        input = AtomPair()  # AtomPair is an existing CompoundParam subclass

        # This will magically override FooTask.output = Output()
        class Output(parameters.CompoundParam):
            charge: float
            processed_atom_pair: AtomPair

3. Define the work of the task. This is done differently for different task
classes, but generally involves overriding a method to either provide python
logic directly as the work to be done or to construct a command line with the
appropriate arguments that will be invoked via the appropriate mechanism for the
task type.
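
For example, a function-based task defines its work by overriding
mainFunction (a minimal sketch; the params shown are illustrative)::

    class FooThreadTask(tasks.ThreadFunctionTask):
        class Input(parameters.CompoundParam):
            x: int
            y: int

        class Output(parameters.CompoundParam):
            total: int

        def mainFunction(self):
            self.output.total = self.input.x + self.input.y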

Once a task is defined, it can be instantiated, set up, and started::

    task = FooThreadTask()
    task.input.x = 3
    task.input.y = 4
    task.start()
    assert task.status is tasks.Status.RUNNING
    task.wait()
    assert task.status is tasks.Status.DONE
    print(task.output)
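
If the task fails instead, its status will be tasks.Status.FAILED and details
are recorded on task.failure_info (a sketch; the exact contents depend on the
failure)::

    task.wait()
    if task.status is tasks.Status.FAILED:
        print(task.failure_info)  # processor name, traceback, and message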

.. warning::
    `wait()` executes a local event loop, so it should not be called directly
    from a GUI - see PANEL-18317 for discussion. `wait()` is safe to call
    inside a subprocess or job (e.g. if a jobtask spawns child tasks).
    Run `git grep "task[.]wait("` to see safe examples annotated with "# OK".

==================
Pre/postprocessors
==================

Tasks support pre/post processing functions. These can either be methods in the
class that are decorated with the preprocessor or postprocessor decorators, or
external functions that are added to a task instance. Example::

    class MyTask(tasks.BlockingFunctionTask):
        @tasks.preprocessor
        def checkInput(self):
            if self.input.x < 0:
                return False, 'x must be a nonnegative number.'
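
Postprocessors work the same way but run after the task's main work has
completed. A sketch (the output param is illustrative)::

    class MyTask(tasks.BlockingFunctionTask):
        @tasks.postprocessor
        def checkOutput(self):
            if not self.output.results:
                return False, 'Task produced no results.'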

For more information, see the module-level preprocessor and postprocessor
decorators as well as the start(), preprocessors(), and addPreprocessor()
methods of AbstractTask.

========================
Task directory (taskdir)
========================

Tasks have a concept of a taskdir. While the task framework will never actually
chdir into a different directory, the task provides functions for specifying
and accessing a directory that is considered that task's directory by
convention. Subprocesses started by the task will use the taskdir as their
working directory.

To specify a taskdir, override AbstractTask.DEFAULT_TASKDIR_SETTING or use
task.specifyTaskDir(). Example::

    class MyTask(tasks.BlockingFunctionTask):
        DEFAULT_TASKDIR_SETTING = tasks.AUTO_TASKDIR

    task = MyTask()
    task.specifyTaskDir('foo_dir')

The taskdir is created during preprocessing. Once the taskdir is created, use
task.getTaskDir() and task.getTaskFilename() when reading and writing files for
the task. Example::

    class MyTask(tasks.SubprocessCmdTask):
        @tasks.preprocessor(order=tasks.AFTER_TASKDIR)
        def writeInputFiles(self):
            with open(self.getTaskFilename('foo_data.txt'), 'w') as f:
                f.write(self.input.foo_data)

For more details on taskdir, see task.specifyTaskDir() and task.getTaskDir().

==========================
Input/Output File Handling
==========================

To specify a task input file or folder, use the `TaskFile` or `TaskFolder`
classes as a subparam on the task.input param. If the task runs its unit
of work on a different machine or process, the input files/folders will
automatically be copied to the right location on the compute host. The
path to the `TaskFile`/`TaskFolder` will also be updated so it points
to the right location, regardless of when or where it's accessed.

`TaskFile`/`TaskFolder`s may be nested under the input param in supported
container types. Supported container types are:

    - List
    - Dict
    - Set
    - Tuple
    - CompoundParam

There are few restrictions on how deeply a `TaskFile`/`TaskFolder` may be
nested under the input param. For example, if you have a variable number of
input files, you can define the input with a list::

    class Input(parameters.CompoundParam):
        receptor_filename: TaskFile
        ligand_filenames: List[TaskFile]

Task output files/folders behave in exactly the same way as task input
files/folders, except that they're defined as `TaskFile` or `TaskFolder` on
the output param.
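
For example (an illustrative sketch)::

    class Output(parameters.CompoundParam):
        pose_filename: TaskFile
        log_filenames: List[TaskFile]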
"""
import contextlib
import copy
import enum
import inspect
import os
import pathlib
import pickle
import random
import shutil
import string
import sys
import tempfile
import traceback
import typing
from collections import namedtuple
from datetime import datetime
from typing import List

from schrodinger.models import json
from schrodinger.models import jsonable
from schrodinger.models import parameters
from schrodinger.models import paramtools
from schrodinger.Qt import QtCore
from schrodinger.Qt.QtCore import QProcess
from schrodinger.tasks import cmdline
from schrodinger.ui.qt.appframework2 import application
from schrodinger.utils import fileutils
from schrodinger.utils import funcchains
from schrodinger.utils import imputils
from schrodinger.utils import qt_utils
from schrodinger.utils import scollections
from schrodinger.utils import subprocess as subprocess_utils


class TaskDirNotFoundError(RuntimeError):
    pass


class TaskFile(str):
    """
    See the "Input/Output File Handling" section of the module docstring for
    information.
    """


class TaskFolder(str):
    """
    See the "Input/Output File Handling" section of the module docstring for
    information.
    """


#===============================================================================
# Task pre/post processing
#===============================================================================

# Ordering constants
BEFORE_TASKDIR = -2000  # Runs preprocessor before taskdir creation
AFTER_TASKDIR = 0  # Runs preprocessor after taskdir creation (default)
_TASKDIR_ORDER = -1000  # Order for taskdir creation
_WRITE_JSON_ORDER = 10000


class TaskDirSetting(enum.Enum):
    AUTO_TASKDIR = enum.auto()
    TEMP_TASKDIR = enum.auto()


# Taskdir settings
AUTO_TASKDIR = TaskDirSetting.AUTO_TASKDIR
TEMP_TASKDIR = TaskDirSetting.TEMP_TASKDIR


class _ProcessorMarker(funcchains.FuncChainMarker):

    def customizeFuncResult(self, func, result):
        return _cast_processing_result(result, func)


"""
The preprocessor and postprocessor decorators can be used to mark functions
to be run before/after a task. These decorators may be used on task methods
either with or without args::

    class MyTask(tasks.BlockingFunctionTask):
        @tasks.preprocessor  # Use without args
        def checkInput(self):
            pass

        @tasks.preprocessor(order=tasks.AFTER_TASKDIR)  # Use with args
        def writeInput(self):
            pass

The optional order argument is a float that is used as a sorting key to
determine the order of execution of pre/postprocessors. It's recommended that
one of the module-level ordering constants is used, with +/- increments to
fine-tune the order. For example::

    class MyTask(tasks.BlockingFunctionTask):
        @tasks.preprocessor(order=tasks.AFTER_TASKDIR)
        def checkInput(self):
            pass

        @tasks.preprocessor(order=tasks.AFTER_TASKDIR + 1)
        def writeInput(self):
            pass

External functions may also be decorated. In this case, the function must
also be added to a task instance. Example::

    @tasks.preprocessor(order=tasks.AFTER_TASKDIR)
    def foo():
        pass

    task = MyTask()
    task.addPreprocessor(foo)

Pre/postprocessors may optionally return a ProcessingResult. As a
convenience, a (passed, message) tuple return value will automatically be
cast into a ProcessingResult by the decorator. Examples::

    @tasks.preprocessor
    def checkInput(self):
        if self.input.x < 0:  # Preprocessing failure
            return False, 'x must be nonnegative.'
        if self.input.x > 100:  # Preprocessing warning
            return True, 'Large values of x may take a long time.'
        return True  # Pass (equivalent to returning None)

Returning False without a message will be a silent failure.
"""
preprocessor = _ProcessorMarker('preprocessor')
postprocessor = _ProcessorMarker('postprocessor')


class ProcessingResult:
    """
    A general-purpose return value for task pre/post processors.
    """

    def __init__(self, passed, message=None):
        """
        :param passed: Whether the result is considered to be passing
        :type passed: bool

        :param message: A message for this result
        :type message: str
        """
        self.func = None
        self.passed = passed
        self.message = message

    def processorName(self):
        if self.func is None:
            return
        return self.func.__name__

    def __bool__(self):
        return self.passed

    def __repr__(self):
        return str(self)

    def __str__(self):
        msg = ''
        if self.func is not None:
            msg += f'{self.func.__name__}: '
        if self.passed and not self.message:
            msg += 'Passed'
        elif self.passed and self.message:
            msg += f'WARNING - {self.message}'
        elif not self.passed and not self.message:
            msg += 'FAILED'
        elif not self.passed and self.message:
            msg += f'FAILED - {self.message}'
        return msg


class CallingContext(enum.IntEnum):
    CMDLINE = enum.auto()
    GUI = enum.auto()


def _cast_processing_result(result, func=None):
    """
    Convert the return value of a pre/post-processor to a ProcessingResult,
    if necessary. If a func is supplied, it will be recorded in the
    ProcessingResult.

    :param func: the function that produced this result

    :param result: the return value of the pre/post-processor. This can be
        represented in one of three ways: (1) True/False for passed,
        (2) a tuple of (passed, message), or (3) a ProcessingResult instance
    :type result: bool, tuple, or ProcessingResult

    :return: the wrapped return value
    :rtype: ProcessingResult
    """
    if result is None:
        result = True
    if isinstance(result, bool):
        result = ProcessingResult(result)
    if isinstance(result, tuple):
        result = ProcessingResult(*result)
    if isinstance(result, ProcessingResult):
        result.func = func
        return result
    raise TypeError(f'Return value should be bool or tuple. Got {result}')


#===============================================================================
# Task exceptions
#===============================================================================


class TaskFailure(Exception):
    """
    Exception raised when a task fails for reasons other than an unexpected
    error occurring during execution.
    """
    # This class intentionally left blank.


class TaskKilled(TaskFailure):
    pass


class _TaskTestTimeout(TaskFailure):
    """
    Exception raised if a task times out under pytest.
    """
    pass


#===============================================================================
# Status
#===============================================================================

FailureInfo = namedtuple('FailureInfo', 'exception traceback message')


class FailureInfo(FailureInfo):

    def __str__(self):
        if self.exception is None:
            return 'No failure recorded.'
        else:
            return f'Task failure:\n{self.traceback}\n{self.exception}'


class Status(jsonable.JsonableIntEnum):
    WAITING, RUNNING, FAILED, DONE = range(4)


FINISHED_STATUSES = {Status.FAILED, Status.DONE}
STARTABLE_STATUSES = {Status.WAITING, Status.FAILED, Status.DONE}
NON_RUNNING_STATUSES = {Status.WAITING, Status.FAILED, Status.DONE}


def _wait(task, timeout=None):
    """
    Block until the task is finished executing or `timeout` seconds have
    passed.

    :param timeout: Amount of time in seconds to wait before timing out. If
        None or a negative number, this method will wait until the task is
        finished.
    :type timeout: NoneType or int

    :return: whether the task finished during the wait. Returns False if wait
        timed out
    """
    return _wait_for(task, NON_RUNNING_STATUSES, timeout=timeout)


@application.require_application(use_qtcore_app=True)
def _wait_for(task, end_statuses, timeout=None):
    """
    Block until a task reaches one of the specified statuses. Blocks using a
    local event loop.

    :param task: the task to wait on
    :param end_statuses: the task statuses to wait for
    :param timeout: an optional timeout in seconds

    :return: whether the wait succeeded. Returns False if wait timed out
    """
    if task.status in end_statuses:
        return True
    event_loop = QtCore.QEventLoop()

    def check_status(status):
        if status in end_statuses:
            event_loop.exit()

    def time_out_event_loop():
        event_loop.exit()

    if timeout is not None:
        QtCore.QTimer.singleShot(timeout * 1000, time_out_event_loop)
    task.statusChanged.connect(check_status)
    event_loop.exec()
    return task.status in end_statuses


#===============================================================================
# Abstract Task
#===============================================================================


@qt_utils.add_enums_as_attributes(Status)
@qt_utils.add_enums_as_attributes(CallingContext)
class AbstractTask(funcchains.FuncChainMixin, parameters.CompoundParam):
    input: parameters.CompoundParam
    output: parameters.CompoundParam
    status: Status
    name: str
    progress: int
    max_progress: int
    progress_string: str
    calling_context = parameters.NonParamAttribute()
    failure_info = parameters.NonParamAttribute()

    # Convenience Signals
    taskDone = QtCore.pyqtSignal()
    taskStarted = QtCore.pyqtSignal()
    taskFailed = QtCore.pyqtSignal()

    DEFAULT_TASKDIR_SETTING = None
    AUTO_TASKDIR = AUTO_TASKDIR  # Add these to the class namespace for
    TEMP_TASKDIR = TEMP_TASKDIR  # convenience.

    _all_task_tempdirs = []
    _is_debug_enabled = False

    #===========================================================================
    # Construction
    #===========================================================================

    @classmethod
    def runFromCmdLine(cls):
        return cmdline.run_task_from_cmdline(cls)

    @classmethod
    def fromJsonFilename(cls, filename):
        with open(filename) as f:
            json_dict = json.load(f)
        task = cls.fromJson(json_dict)
        return task

    def initConcrete(self):
        super().initConcrete()
        self.statusChanged.connect(self.__onStatusChanged)
        self.failure_info = None
        self._taskdir = None
        self._taskdir_setting = self.DEFAULT_TASKDIR_SETTING
        self.calling_context = None
        self._in_preprocessing = False
        self._interruption_requested = False
        self._tempdir = None

    def initializeValue(self):
        """
        @overrides: parameters.CompoundParam
        """
        if not self.name:
            self.name = self.__class__.__name__

    #===========================================================================
    # Abstract Methods
    #===========================================================================

    INTERRUPT_ENABLED = False

    def run(self):
        # Implementations of run are responsible for directly calling
        # `_finish` or connecting a signal to `_finish`.
        raise NotImplementedError()

    def kill(self):
        """
        Implementations are responsible for immediately stopping the task. No
        threads or processes should be running after this method is complete.

        This method should be called sparingly since in many contexts the
        task will be forced to terminate without a chance to clean up or free
        resources.
        """
        raise NotImplementedError()

    #===========================================================================
    # Public API
    #===========================================================================

    def start(self, skip_preprocessing=False):
        """
        This is the main method for starting a task. Start will check that
        the task is not already running, run preprocessing, and then run the
        task. Failures in preprocessing will interrupt the task start, and
        the task will never enter the RUNNING state.

        :param skip_preprocessing: whether to skip preprocessing. This can be
            useful if preprocessing was already performed prior to calling
            start.
        :type skip_preprocessing: bool
        """
        self.printDebug('start')
        if not self.isStartable():
            raise RuntimeError(
                f"Can't start a task with status {self.status.name}")
        if not self.name:
            raise RuntimeError("Can't start a task with name: ''")
        self.status = Status.WAITING
        self._interruption_requested = False
        self.failure_info = None
        if not skip_preprocessing:
            with self.guard():
                self.runPreprocessing(callback=self._processingCallback)
            if self.failure_info is not None:
                self.status = self.FAILED
                return
        self.status = self.RUNNING
        with self.guard():
            self.run()
        if self.failure_info is not None:
            self.status = self.FAILED
            return

    def wait(self, timeout=None):
        r"""
        Block until the task is finished executing or `timeout` seconds have
        passed.

        .. warning::
            This should not be called directly from GUI code - see
            PANEL-18317. It is safe to call inside a subprocess or job. Run
            `git grep "task\.wait("` to see safe examples annotated with
            "# OK".

        :param timeout: Amount of time in seconds to wait before timing out.
            If None or a negative number, this method will wait until the
            task is finished.
        :type timeout: NoneType or int
        """
        # Call the module-level wait function
        self.printDebug(f'wait({timeout})')
        try:
            with self.guard():
                return _wait(self, timeout)
        finally:
            self.printDebug('wait done')

    def isRunning(self):
        return self.status is self.RUNNING

    def isStartable(self):
        return self.status in STARTABLE_STATUSES

    def specifyTaskDir(self, taskdir_spec):
        """
        Specify the taskdir creation behavior. Use one of the following
        options:

            - A directory name (string). This may be a relative or absolute
              path.
            - None - no taskdir is requested. The task will use the CWD as
              its taskdir.
            - AUTO_TASKDIR - a new subdirectory will be created in the CWD
              using the task name as the directory name.
            - TEMP_TASKDIR - a temporary directory will be created in the
              schrodinger temp dir. This directory is cleaned up when the
              task is deleted.

        :param taskdir_spec: one of the four options listed above
        """
        if ((self._in_preprocessing and self._taskdir is not None) or
                self.isRunning()):
            raise RuntimeError('Taskdir specification may not be changed once '
                               'the taskdir is created.')
        self._taskdir_setting = taskdir_spec
        self._taskdir = None

    def taskDirSetting(self):
        """
        Returns the taskdir spec. See specifyTaskDir() for details.
        """
        return self._taskdir_setting

    def getTaskDir(self):
        """
        Returns the full path of the task directory. This is only available
        once the taskdir has been created (or, if no taskdir is specified, at
        any time).
        """
        if self._taskdir_setting is None:
            return os.getcwd()
        if isinstance(self._taskdir_setting, (str, pathlib.Path)):
            if os.path.exists(self._taskdir_setting):
                self._taskdir = os.path.abspath(self._taskdir_setting)
        if self._taskdir is None:
            raise TaskDirNotFoundError('Taskdir has not been created yet. '
                                       'Consider moving this call to an '
                                       'AFTER_TASKDIR preprocessor.')
        return self._taskdir

    def getTaskFilename(self, fname):
        """
        Return the appropriate absolute path for an input or output file in
        the taskdir.
        """
        parent_dir = self.getTaskDir()
        return os.path.join(parent_dir, fname)

    def addPreprocessor(self, func, order=None):
        """
        Adds a preprocessor function to this task instance. If the function
        has been decorated with @preprocessor, the order specified by the
        decorator will be used as the default.

        :param func: the function to add

        :param order: the sorting order for the function relative to all
            other preprocessors. Takes precedence over order specified by the
            preprocessor decorator.
        :type order: float
        """
        if order is None:
            decorated_order = funcchains.get_marked_func_order(func)
            if decorated_order is None:
                order = AFTER_TASKDIR
            else:
                order = decorated_order
        self.addFuncToGroup(func, preprocessor, order)

    def addPostprocessor(self, func, order=0):
        """
        Adds a postprocessor function to this task instance. If the function
        has been decorated with `@postprocessor`, the order specified by the
        decorator will be used.

        :param func: the function to add
        :type func: typing.Callable

        :param order: the sorting order for the function relative to all
            other postprocessors. Takes precedence over order specified by
            the postprocessor decorator.
        :type order: float
        """
        self.addFuncToGroup(func, postprocessor, order)

    def preprocessors(self):
        """
        :return: A list of preprocessors (both decorated methods on the task
            and external functions that have been added via addPreprocessor)
        """
        return self.getFuncGroup(preprocessor)

    def postprocessors(self):
        """
        :return: A list of postprocessors, both decorated methods on the task
            and external functions that have been added via
            `addPostprocessor()`
        :rtype: list[typing.Callable]
        """
        return self.getFuncGroup(postprocessor)

    def reset(self, *args, **kwargs):
        if not args and not kwargs:
            if self.status is self.RUNNING:
                raise RuntimeError("Can't reset a task while it's running")
            elif self.status is self.FAILED:
                self.failure_info = None
        super().reset(*args, **kwargs)

    def replicate(self):
        """
        Create a new task with the same input and settings (but no output).
        """
        old_task = self
        new_task = self.__class__()
        new_task.specifyTaskDir(old_task.taskDirSetting())
        old_preprocess_callbacks = old_task.getAddedFuncs(preprocessor)
        for func, order in old_preprocess_callbacks:
            new_task.addPreprocessor(func, order)
        for func, order in old_task.getAddedFuncs(postprocessor):
            new_task.addPostprocessor(func, order)
        if isinstance(new_task.input, parameters.CompoundParam):
            new_task.input.setValue(old_task.input)
        else:
            new_task.input = old_task.input
        return new_task

    def isDebugEnabled(self):
        return self._is_debug_enabled

    def printDebug(self, *args):
        if not self.isDebugEnabled():
            return
        info = self.getDebugString()
        print(f'{info}:', *args)

    def getDebugString(self):
        return f'{datetime.now()} {self.name}-{self.status.name}'

    def requestInterruption(self):
        """
        Request the task to stop. To enable this feature, subclasses should
        periodically check whether an interruption has been requested and
        terminate if it has been. If such logic has been included,
        `INTERRUPT_ENABLED` should be set to `True`.
        """
        if not self.INTERRUPT_ENABLED:
            raise RuntimeError("Interruption is not enabled for this task.")
        self._interruption_requested = True

    def isInterruptionRequested(self):
        return self._interruption_requested
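
    # A minimal sketch of a subclass that supports interruption
    # (`process` and `input.items` are hypothetical)::
    #
    #     class MyTask(tasks.ThreadFunctionTask):
    #         INTERRUPT_ENABLED = True
    #
    #         def mainFunction(self):
    #             for item in self.input.items:
    #                 if self.isInterruptionRequested():
    #                     raise tasks.TaskFailure('Task was interrupted.')
    #                 process(item)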

    #===========================================================================
    # Internal methods
    #===========================================================================

    @preprocessor(order=BEFORE_TASKDIR - 1000)
    def _validateTaskName(self):
        is_valid = fileutils.is_valid_jobname(self.name)
        if not is_valid:
            return False, fileutils.INVALID_JOBNAME_ERR % self.name

    def __copy__(self):
        task_copy = super().__copy__()
        if task_copy.status is task_copy.RUNNING:
            task_copy.status = task_copy.WAITING
        return task_copy

    def __deepcopy__(self, memo):
        task_copy = super().__deepcopy__(memo)
        if task_copy.status is task_copy.RUNNING:
            task_copy.status = task_copy.WAITING
        return task_copy

    def __eq__(self, other):
        """
        Tasks compare equal if all params excluding the status are equal.
        """
        is_eq = super().__eq__(other)
        if is_eq:
            return True
        else:
            if isinstance(other, self.__class__):
                self_copy = copy.copy(self)
                other_copy = copy.copy(other)
                return self_copy.toDict() == other_copy.toDict()
            return False

    def __onStatusChanged(self, status):
        if status is self.RUNNING:
            self.taskStarted.emit()
        elif status is self.FAILED:
            self.taskFailed.emit()
        elif status is self.DONE:
            self.taskDone.emit()

    def _processingCallback(self, result):
        if not result.passed:
            self._recordFailure(TaskFailure(result.message))
        return result.passed

    def _defaultResultCallback(self, result):
        """
        @overrides: funcchains.FuncChainMixin
        """
        if not result.passed:
            raise TaskFailure(result.message)
        return True

    @typing.final
    def runPreprocessing(self, callback=None, calling_context=None):
        """
        Run the preprocessors one-by-one. By default, any failing
        preprocessor will raise a TaskFailure exception and terminate
        processing. This behavior may be customized by supplying a callback
        function, which will be called after each preprocessor with the
        result of that preprocessor.

        This method is "final" so that all preprocessing logic will be
        enclosed in the try/finally block.

        :param callback: a function that takes a result and returns a bool
            indicating whether to continue on to the next preprocessor

        :param calling_context: specify a value here to indicate the context
            in which this preprocessing is being called. This value will be
            stored in an instance variable, self.calling_context, which can
            be accessed from any preprocessor method on this task. Typically
            this value will be either self.GUI, self.CMDLINE, or None, but
            any value may be supplied here and checked for in the
            preprocessor methods. self.calling_context always reverts back to
            None at the end of runPreprocessing.
        """
        self.printDebug('runPreprocessing')
        self._in_preprocessing = True
        self._taskdir = None
        self.calling_context = calling_context
        try:
            return self.processFuncChain(preprocessor,
                                         result_callback=callback)
        finally:
            self.calling_context = None
            self._in_preprocessing = False
            self.printDebug('done preprocessing')

    def _runPostprocessing(self, callback=None):
        return self.processFuncChain(postprocessor, result_callback=callback)

    def _makeTempTaskDir(self):
        parent_dir = fileutils.get_directory_path(fileutils.TEMP)
        self._tempdir = tempfile.TemporaryDirectory(dir=parent_dir)
        self._taskdir = self._tempdir.name
        self._registerTempDir(self._tempdir)

    def _registerTempDir(self, tmpdir):
        """
        Register a tempdir to the class. This is used to clean up all
        tempdirs in unit tests.
        """
        self._all_task_tempdirs.append(tmpdir)

    def _makeDir(self, taskdir):
        os.makedirs(taskdir)

    @preprocessor(order=_TASKDIR_ORDER)
    def _createTaskDir(self):
        """
        Create a task directory for running the task in.
        """
        if self._taskdir_setting is TEMP_TASKDIR:
            self._makeTempTaskDir()
            return True
        cwd = os.getcwd()
        if self._taskdir_setting is None:
            self._taskdir = cwd
            return True
        if self._taskdir_setting is AUTO_TASKDIR:
            taskdir = os.path.abspath(self.name)
        else:
            taskdir = os.path.abspath(self._taskdir_setting)
        self._taskdir = taskdir
        try:
            self._makeDir(taskdir)
        except FileExistsError:
            if self._taskdir_setting is not AUTO_TASKDIR:
                # Allow specified path to already exist
                return True
            if self.calling_context is self.GUI:
                return (True,
                        f'Task directory {self._taskdir} already exists. '
                        'Contents will be overwritten. Continue?')
            return False, f"Task directory {self._taskdir} already exists."
        return True

    def _recordFailure(self, exception, exc_traceback_str=None):
        """
        Store the exception in `failure_info` and set status to failed.
        """
        if self.failure_info is not None:
            return
        message = str(exception)
        self.failure_info = FailureInfo(exception=exception,
                                        traceback=exc_traceback_str,
                                        message=message)
        if exc_traceback_str:
            tb = exc_traceback_str
        else:
            tb = ''
        print(
            f'{tb}{repr(self)}> failed: {type(exception).__name__}("{message}")'
        )

    @contextlib.contextmanager
    def guard(self):
        """
        Context manager that saves any Exception raised inside.
        """
        try:
            yield
        except Exception:
            err_type, exc_value, exc_traceback = sys.exc_info()
            if err_type is TaskFailure:
                exc_traceback_str = None
            else:
                exc_traceback_str = ''.join(
                    traceback.format_tb(exc_traceback)[-10:])
            # We have to delete the traceback to prevent a circular ref.
            # See the `traceback` module documentation for additional info.
            del exc_traceback
            self._recordFailure(exc_value, exc_traceback_str)

    def _finish(self):
        self.printDebug('_finish')
        if self.failure_info is not None:
            self.status = Status.FAILED
            return
        with self.guard():
            self._runPostprocessing(callback=self._processingCallback)
        if self.failure_info is not None:
            self.status = Status.FAILED
            return
        self.status = Status.DONE

    def __repr__(self):
        if self.isAbstract():
            return super().__repr__()
        return (f'<{self.__class__.__name__}: {self.name} - '
                f'{Status(self.status).name}>')  # sometimes status is an int

    @classmethod
    def _populateClassParams(cls):
        cls._convertNestedClassToDescriptor('Input', 'input')
        cls._convertNestedClassToDescriptor('Output', 'output')
        super()._populateClassParams()

    @classmethod
    def _convertNestedClassToDescriptor(cls, nested_class_name,
                                        descriptor_name):
        """
        If a nested class of the specified name is defined, this method will
        instantiate that class and set that instance as a class variable.
        Ex::

            class Foo:
                class Bar:
                    pass

        Calling Foo._convertNestedClassToDescriptor('Bar', 'bar') will do the
        equivalent of putting bar = Bar() inside the Foo class. Typically
        used to instantiate Param classes as descriptors on the class.

        :param nested_class_name: the name of the class to look for

        :param descriptor_name: the name under which the descriptor instance
            will be added to the class
        """
        if nested_class_name in cls.__dict__:
            nested_class = getattr(cls, nested_class_name)
            desc = nested_class()
            desc.__set_name__(cls, descriptor_name)
            setattr(cls, descriptor_name, desc)


#===============================================================================
# Task interfaces
#===============================================================================


class _AbstractFunctionTask(AbstractTask):

    def run(self):
        self._runMainFunction()

    def _guardedMain(self):
        with self.guard():
            self.mainFunction()

    def _runMainFunction(self):
        raise NotImplementedError()

    def mainFunction(self):
        raise NotImplementedError()


class AbstractCmdTask(AbstractTask):

    def run(self):
        cmd = self.makeCmd()
        for idx, arg in enumerate(cmd):
            if not isinstance(arg, str):
                msg = (f"makeCmd() must return a list of strings. Item {idx} "
                       f"is type {type(arg)}.")
                raise ValueError(msg)
        self.runCmd(cmd)

    def runCmd(self, cmd):
        raise NotImplementedError()

    def makeCmd(self):
        return []


class AbstractComboTask(AbstractCmdTask, _AbstractFunctionTask):
    """
    Subclasses should only define params inside of input or output. Top-level
    params defined in subclasses do NOT get serialized between the frontend
    and backend task instances. Thus, any modifications of new top-level
    params in the backend (i.e. mainFunction) will not have any effect on the
    rehydrated frontend task.
    """
    _run_as_backend: bool = False
    ENTRYPOINT = 'combotask_entry_point.py'

    # Private params, not for use by child classes
    _task_module: str
    _task_class: str
    _task_script: str
    _failure_info: str = None
    _failure_tb: str = None
    _combo_id: str = None

    # Only these params will be serialized in frontend/backend conversions
    _FRONTEND_TO_BACKEND_PARAMS = [
        'name', 'input', '_run_as_backend', '_task_module', '_task_class',
        '_task_script', '_combo_id'
    ]
    _BACKEND_TO_FRONTEND_PARAMS = [
        'output', 'status', '_run_as_backend', '_failure_info', '_failure_tb'
    ]

    def _regenerateComboId(self):
        """
        Generate a new combo id for this task. A combo id is a random string
        that is used to prevent tasks with the same task name from
        overwriting each other's combo files (i.e. _frontend.json and
        _backend.json).
        """
        alphabet = string.ascii_lowercase + string.digits
        self._combo_id = ''.join(random.choices(alphabet, k=12))

    def initializeValue(self):
        super().initializeValue()
        if self._combo_id is None:
            # no combo id from rehydrated json file
            self._regenerateComboId()

    @property
    def json_filename(self):
        return self.getTaskFilename(
            f'.{self.name}_{self._combo_id}_frontend.json')

    @property
    def json_out_filename(self):
        return self.getTaskFilename(
            f'.{self.name}_{self._combo_id}_backend.json')

    def start(self, *args, **kwargs):
        """
        @overrides: AbstractTask
        """
        if self.isBackendMode():
            return self.runBackend()
        super().start(*args, **kwargs)

    def isBackendMode(self):
        return self._run_as_backend

    def makeCmd(self):
        """
        @overrides: AbstractCmdTask
        """
        cmd = [
            get_schrodinger_run(), self.ENTRYPOINT, '--task_json',
            self._getFrontEndJsonArg()
        ]
        return cmd

    def _getFrontEndJsonArg(self):
        return self.json_filename

    def _writeFrontendJsonFile(self):
        task_module = self._get_module()
        # deepcopy of a compoundparam only copies params
        backend_task = copy.deepcopy(self)
        backend_task._taskdir = self._taskdir
        if task_module == '__main__':
            print(f'{self} is defined outside the build. Will attempt to copy '
                  'script to backend dir to run. If the script needs to import '
                  'other files, the task will still fail. In this case, move '
                  'the script and its dependencies to an importable location.')
            cp_filename = self._copyScriptToBackend()
            backend_task._task_script = os.path.basename(cp_filename)
        backend_task._task_module = task_module
        backend_task._task_class = type(self).__name__
        # need to get json_filename before setting _run_as_backend to True
        json_filename = self.json_filename
        backend_task._processTaskFilesForFrontendWrite()
        backend_task._run_as_backend = True
        backend_task._writeComboJsonFile(json_filename)

    def _copyScriptToBackend(self):
        script_filename = inspect.getfile(type(self))
        try:
            return shutil.copy(script_filename, self.getTaskDir())
        except shutil.SameFileError:
            return script_filename

    @preprocessor(order=_WRITE_JSON_ORDER)
    def _prepareComboTask(self, *args, **kwargs):
        self._writeFrontendJsonFile()

    def _finish(self):
        super()._finish()
        # The next time this task is started, it should have a new combo id
        self._regenerateComboId()

    def backendMain(self):
        raise NotImplementedError

    def _processBackend(self):
        json_out_path = self.json_out_filename
        if not os.path.isfile(json_out_path):
            msg = "No json file was returned from the backend. "
            logfile = self.getTaskFilename(self._getLogFilename())
            if os.path.isfile(logfile):
                msg += f"Check {logfile} for more information."
                self.printDebug(f'Log file contents:\n{self.getLogAsString()}')
            else:
                msg += f"Log file not found at {logfile}"
            exception = RuntimeError(msg)
            self._recordFailure(exception)
        else:
            with open(json_out_path, 'r') as infile:
                # Create a new instance from the backend json output
                TaskClass = type(self)
                try:
                    rehydrated_backend = TaskClass.fromJson(json.load(infile))
                except json.JSONDecodeError as e:
                    self._recordFailure(e)
                else:
                    self._updateFromBackend(rehydrated_backend)

    def _updateFromBackend(self, rehydrated_backend):
        """
        Update the frontend task based on the rehydrated backend task.
        """
        if isinstance(self.output, parameters.CompoundParam):
            self.output.setValue(rehydrated_backend.output)
            self._processTaskFilesForBackendRehydration()
        else:
            self.output = rehydrated_backend.output
        if rehydrated_backend.status == rehydrated_backend.FAILED:
            backend_exc = pickle.loads(
                rehydrated_backend._failure_info.encode())
            backend_tb = rehydrated_backend._failure_tb
            self._recordFailure(backend_exc, backend_tb)

    def _writeComboJsonFile(self, filename):
        if self.status is self.FAILED:
            # Use protocol 0 since it's ascii-encodable
            self._failure_info = pickle.dumps(self.failure_info.exception,
                                              0).decode()
            backend_tb = ''.join(
                traceback.format_tb(
                    self.failure_info.exception.__traceback__))
            self._failure_tb = backend_tb
        ser_task = self._createSerializationTask()
        try:
            with open(filename, 'w') as f:
                json.dump(ser_task, f, indent=4)
        except:
            # If something goes wrong during serialization, we should make
            # sure to remove the empty json file.
            os.remove(filename)
            raise

    def runBackend(self):
        self._processTaskFilesForBackendExecution()
        self.progressChanged.connect(self._onBackendProgressChanged)
        self.max_progressChanged.connect(self._onBackendProgressChanged)
        self.progress_stringChanged.connect(self._onBackendProgressChanged)
        with self.guard():
            try:
                self.backendMain()
            except NotImplementedError:
                self.mainFunction()
        with self.guard():
            self._processTaskFilesForBackendWrite()
        if self.failure_info:
            self.status = self.FAILED
            if not isinstance(self.failure_info.exception, TaskFailure):
                print(self.failure_info.traceback)
            if self.failure_info.message:
                print(self.failure_info.message)
        # Mark as frontend to ensure correct params are serialized
        self._run_as_backend = False
        self._writeComboJsonFile(self.json_out_filename)

    def _onBackendProgressChanged(self):
        """
        Implement logic that will communicate progress change from the
        backend to the frontend.
        """

    def _get_module(self):
        """
        Return the module string defining where the class for `self` is
        defined.
        """
        return imputils.get_path_from_module(inspect.getmodule(self))

    def _createSerializationTask(self) -> 'AbstractComboTask':
        """
        Return a new instance of this task that has serialization param
        values set for frontend/backend conversion. Non-serialization params
        have default values.
        """
        ser_task = self.__class__()
        ser_param_names = self._getSerializationParamNames()
        for param_name in ser_param_names:
            param_value = getattr(self, param_name)
            if isinstance(param_value, parameters.CompoundParam):
                param_to_serialize = getattr(ser_task, param_name)
                param_to_serialize.setValue(param_value)
            else:
                setattr(ser_task, param_name, param_value)
        return ser_task

    def _getSerializationParamNames(self) -> List[str]:
        """
        Return a list of the names of params that should be serialized for
        frontend/backend combo task conversion.
        """
        if self._run_as_backend:
            param_names = self._FRONTEND_TO_BACKEND_PARAMS
        else:
            param_names = self._BACKEND_TO_FRONTEND_PARAMS
        return param_names

    #===========================================================================
    # TaskFile Processing
    #===========================================================================

    def _processTaskFilesForFrontendWrite(self):
        """
        This will be called before writing out the combotask frontend json
        file. Transforms all TaskFile and TaskFolder paths in self.input so
        that the json file within the taskdir will be portable, if possible.

        Raises a ValueError if any files/directories do not exist.
        """

        def process_input(path, launchdir):
            path = os.path.abspath(path)
            return path

        self._assertTaskFileExistence(self.input)
        self._processTaskFiles(self.input, process_func=process_input)

    def _processTaskFilesForBackendExecution(self):
        """
        This will be called in the backend before executing the mainFunction
        of the combotask. Override if the file paths are different in the
        backend compared to the paths used in the frontend.

        Raises a ValueError if any files/directories do not exist.
        """
        self._assertTaskFileExistence(self.input)

    def _processTaskFilesForBackendWrite(self):
        """
        This will be called in the backend after the mainFunction returns,
        before writing the combotask backend json file. Converts absolute
        paths into relative paths so that file references can remain valid if
        the taskdir is copied or moved.

        Raises a ValueError if any files/directories do not exist.
        """

        def process_output(path, launchdir):
            path = os.path.relpath(path)
            return path

        self._assertTaskFileExistence(self.output)
        self._processTaskFiles(self.output, process_func=process_output)

    def _processTaskFilesForBackendRehydration(self):
        """
        This will be called before the output of the backend task is set back
        on the frontend task.

        Raises a ValueError if any files/directories do not exist.
        """
        self._assertTaskFileExistence(self.output)

    def _assertTaskFileExistence(self, param):

        def assert_taskfile_existence(path):
            if path is None:
                return None
            if not os.path.exists(path):
                raise ValueError(
                    f'Filepath "{path}" does not exist. Make sure all '
                    'taskfiles and task folders point to existing files '
                    'before starting or completing the task.')
            return path

        if isinstance(param, parameters.CompoundParam):
            paramtools.map_subparams(assert_taskfile_existence, param,
                                     TaskFile)
        if isinstance(param, parameters.CompoundParam):
            paramtools.map_subparams(assert_taskfile_existence, param,
                                     TaskFolder)

    def _processTaskFiles(self, param, *, process_func, dir=None):
        if dir is None:
            dir = self.getTaskDir()

        def process_taskfile(path):
            if path is None:
                return None
            if process_func is None:
                return path
            else:
                new_path = process_func(path, dir)
                return new_path

        if isinstance(param, parameters.CompoundParam):
            paramtools.map_subparams(process_taskfile, param, TaskFile)
        if isinstance(param, parameters.CompoundParam):
            paramtools.map_subparams(process_taskfile, param, TaskFolder)


#===============================================================================
# Task execution mixins
#===============================================================================


def get_schrodinger_run():
    return 'run'


class _SaveTaskReferenceMixin:

    def __init_subclass__(cls):
        super().__init_subclass__()
        # Let each class have its own set so failures are easier to understand
        cls._saved_task_references = scollections.IdSet()

    def start(self, *args, **kwargs):
        super().start(*args, **kwargs)
        if self.status == Status.RUNNING:
            self._saveTaskReference()

    def _finish(self):
        super()._finish()
        self._discardTaskReference()

    def _saveTaskReference(self):
        self._saved_task_references.add(self)

    def _discardTaskReference(self):
        self._saved_task_references.discard(self)


class BlockingMixin:
    """
    Compatible with subclasses of `_AbstractFunctionTask`.
    """

    def _runMainFunction(self):
        self._guardedMain()
        self._finish()


class ThreadMixin(_SaveTaskReferenceMixin):
    MAX_THREAD_TASKS = 500
    qthread = parameters.NonParamAttribute()

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.qthread = None

    def kill(self):
        """
        @overrides: AbstractTask

        Killing threads is dangerous and can lead to deadlocks on Windows, so
        we intentionally leave this unimplemented rather than using
        QThread.terminate.
        """
        raise NotImplementedError

    def _runMainFunction(self):
        # Make sure that there is a QApplication running. If there isn't,
        # create a QCoreApplication.
        application.get_application(create=True, use_qtcore_app=True)
        self.qthread = QtCore.QThread()
        # TODO: Decide whether to leave this as a monkey-patch or hook up
        # qthread.started to _guardedMain instead. If we leave it as a patch,
        # we should add a strong warning against calling .start() from
        # multiple threads.
        self.qthread.run = self._guardedMain
        self.qthread.finished.connect(self.__onThreadFinished)
        self.qthread.start()

    @typing.final
    def __onThreadFinished(self):
        self._finish()


class QProcessError(Exception):

    def __init__(self, message):
        super().__init__(message)


class QProcessFailedToStartError(QProcessError):
    pass


class QProcessCrashedError(QProcessError):
    pass


class QProcessTimedout(QProcessError):
    pass


class QProcessWriteError(QProcessError):
    pass


class QProcessReadError(QProcessError):
    pass


class QProcessUnknownError(QProcessError):
    pass


_QProcessErrorToException = {
    QProcess.FailedToStart: QProcessFailedToStartError,
    QProcess.Crashed: QProcessCrashedError,
    QProcess.Timedout: QProcessTimedout,
    QProcess.WriteError: QProcessWriteError,
    QProcess.ReadError: QProcessReadError,
    QProcess.UnknownError: QProcessUnknownError
}


class SubprocessMixin(_SaveTaskReferenceMixin):
    cmd = parameters.NonParamAttribute()
    exit_code = parameters.NonParamAttribute()
    qprocess = parameters.NonParamAttribute()

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.cmd = None
        self.exit_code = None
        self.qprocess = None
        self._printing_output_to_terminal = False

    def printingOutputToTerminal(self):
        """
        :return: whether the `StdOut` and `StdErr` output from this task is
            being printed to the terminal
        :rtype: bool
        """
        return self._printing_output_to_terminal

    def setPrintingOutputToTerminal(self, print_to_terminal):
        """
        Set this task to print `StdOut` and `StdErr` output to terminal, or
        not.

        :param print_to_terminal: whether to send process output to terminal
        :type print_to_terminal: bool
        """
        self._printing_output_to_terminal = print_to_terminal

    def runCmd(self, cmd):
        # Make sure that there is a QApplication running. If there isn't,
        # create a QCoreApplication.
        application.get_application(create=True, use_qtcore_app=True)
        self.exit_code = None
        self.qprocess = None
        self.cmd = cmd
        cmd[0] = subprocess_utils.abs_schrodinger_path(cmd[0])
        self._setupQProcess()
        self.qprocess.start(cmd[0], cmd[1:])

    def _setupQProcess(self):
        self.qprocess = QtCore.QProcess()
        if self.printingOutputToTerminal():
            self.qprocess.setProcessChannelMode(
                QtCore.QProcess.ForwardedChannels)
        else:
            self.qprocess.setProcessChannelMode(QtCore.QProcess.MergedChannels)
            self.qprocess.setStandardOutputFile(self._getLogFilename())
        self.qprocess.setWorkingDirectory(self.getTaskDir())
        self.qprocess.finished.connect(self.__onSubprocessCompleted)
        self.qprocess.errorOccurred.connect(self.__onErrorOccurred)

    @typing.final
    def __onSubprocessCompleted(self):
        with self.guard():
            self._onSubprocessCompleted()
        self._finish()

    def _onSubprocessCompleted(self):
        self.exit_code = self.qprocess.exitCode()
        if self.exit_code != 0:
            msg = f'{self} returned non-zero exit code.'
            log_str = self.getLogAsString()
            if len(log_str) > 200:  # Elide
                log_str = log_str[:200] + '...'
            msg += f'\n{log_str}'
            self._recordFailure(TaskFailure(msg))

    @typing.final
    def __onErrorOccurred(self, error):
        with self.guard():
            self._onErrorOccurred(error)
        self._finish()

    def _onErrorOccurred(self, error):
        qprocess_exception = _QProcessErrorToException[error](
            message=f"Command: {self.cmd} had fatal error: "
                    f"{self.qprocess.errorString()}")
        self.exit_code = self.qprocess.exitCode()
        self._recordFailure(qprocess_exception)

    def _getLogFilename(self):
        return self.getTaskFilename(self.name + '.log')

    def getLogAsString(self) -> str:
        log_fn = self.getTaskFilename(self._getLogFilename())
        if not os.path.isfile(log_fn):
            return f'Log file not found: {log_fn}'
        with open(log_fn) as log_file:
            return log_file.read()

    def kill(self):
        """
        @overrides: AbstractTask

        Kill the subprocess and set the status to FAILED.
        """
        if self.status is not self.RUNNING:
            raise RuntimeError("Can't kill a task that's not running.")
        if self.qprocess:
            self.qprocess.finished.disconnect(self.__onSubprocessCompleted)
            self.qprocess.errorOccurred.disconnect(self.__onErrorOccurred)
            self.qprocess.kill()
            self.qprocess.waitForFinished()
        self._recordFailure(TaskKilled())
        self._finish()


#===============================================================================
# Prepackaged Task Classes
#===============================================================================


class BlockingFunctionTask(BlockingMixin, _AbstractFunctionTask):
    """
    A task that simply runs a function and blocks for the duration of it. To
    use, implement `mainFunction`.
    """


class ThreadFunctionTask(ThreadMixin, _AbstractFunctionTask):
    """
    A task that runs a function in a separate thread. To use, implement
    `mainFunction`.

    Note: this class should not be used except in limited circumstances, as
    much of our internal code is not thread safe (e.g. structure.Structure -
    see PANEL-16783). New implementations will have to register their usage
    in test_thread_usage.py, and include the following warning in the
    mainFunction of the task::

        # This logic will be run in a worker thread and must not
        # access thread-unsafe libraries, including structure.Structure.
    """


class SubprocessCmdTask(SubprocessMixin, AbstractCmdTask):
    """
    A task that launches a subprocess. To use, implement `makeCmd` and return
    a list of strings.
    """
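

# A minimal sketch of a SubprocessCmdTask subclass (my_script.py and the
# input filename are hypothetical)::
#
#     class MyCmdTask(tasks.SubprocessCmdTask):
#         def makeCmd(self):
#             return [tasks.get_schrodinger_run(), 'my_script.py',
#                     self.getTaskFilename('input.txt')]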


class ComboBlockingFunctionTask(AbstractComboTask):
    """
    This is mostly for testing purposes.
    """

    def runCmd(self, cmd):
        cls = type(self)
        backend_task = cls.fromJsonFilename(self.json_filename)
        backend_task.specifyTaskDir(self.getTaskDir())
        backend_task.start()
        os.rename(backend_task.json_out_filename, self.json_out_filename)
        self._processBackend()
        self._finish()


class ComboSubprocessTask(SubprocessMixin, AbstractComboTask):
    """
    A task that runs a function in a subprocess. To use, implement
    `mainFunction`.
    """

    def _processTaskFilesForBackendRehydration(self):

        def process_input(path, backend_dir):
            return self.getTaskFilename(path)

        self._processTaskFiles(self.output, process_func=process_input)
        super()._processTaskFilesForBackendRehydration()

    def runBackend(self):
        # Specify the task dir as the cwd since we've already chdir'd into
        # the directory with all the task files
        self.specifyTaskDir(None)
        return super().runBackend()

    def getTaskDir(self):
        if self.isBackendMode():
            return ''
        return super().getTaskDir()

    def _finish(self):
        with self.guard():
            self._processBackend()
        super()._finish()


class SignalTask(AbstractTask):
    """
    A task that relies on signals to proceed. Runs asynchronously via the
    event loop without requiring a worker thread.

    To use, implement setUpMain to connect any per-run signals and slots. Any
    slots should be decorated with SignalTask.guard_method so that exceptions
    in slots get converted into task failures.

    To end the task, emit self.mainDone to indicate the task has successfully
    completed. To fail, raise a TaskFailure or other exception.
    """
    mainDone = QtCore.pyqtSignal()

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.mainDone.connect(self._finish)

    @staticmethod
    def guard_method(func):

        def wrapped_func(self, *args, **kwargs):
            with self.guard():
                return func(self, *args, **kwargs)
            # Only reached if `func` raised; guard() swallows the exception
            # and records it, so finish the task as a failure.
            if self.failure_info:
                self._finish()

        return wrapped_func

    def run(self):
        with self.guard():
            self.setUpMain()

    def setUpMain(self):
        raise NotImplementedError()
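

# A minimal sketch of a SignalTask subclass (the downloader object, its
# signals, and the input/output params are hypothetical)::
#
#     class DownloadTask(tasks.SignalTask):
#         def setUpMain(self):
#             self._downloader = make_downloader(self.input.url)
#             self._downloader.finished.connect(self._onDownloadFinished)
#             self._downloader.start()
#
#         @tasks.SignalTask.guard_method
#         def _onDownloadFinished(self):
#             self.output.data = self._downloader.data()
#             self.mainDone.emit()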