import inspect
import traceback
from typing import Optional, List, Iterable
from schrodinger.application.desmond.arkdb import ArkDbGetError, ArkDb, Datum
from schrodinger.application.desmond import util
class SubtaskExecutionError(RuntimeError):
    """Raised by `Task.execute` when the execution of any subtask failed."""
class Task:
    """
    Base class of executable tasks. An instance of this class defines a
    concrete task to be executed. All subclasses are expected to implement the
    `__init__` and the `execute` methods. The `execute` should be either a
    public callable attribute or a public method. See
    `ParchTrajectoryForFepLambda` below for an example.

    A task can be composed of one or more subtasks. The relationship among the
    premises of this task and its subtasks is the following:

    - If this task's premises are not met, no subtasks will be executed.
    - Failure of one subtask will NOT affect other subtasks being executed.

    Six public attributes/properties:

    - name: An arbitrary name for the task. Useful for error logging.
    - is_completed: A boolean value indicating if the particular task has been
      completed successfully.
    - results: A list of `Datum` objects as the results of the execution of
      the task. The data will be automatically put into the database.
    - errlog: A list of strings recording the error messages (if any) during
      the last execution of the task. The list is empty if there were no
      errors at all.
    - premises: A list of lists of (argument name, `Premise`) pairs. The first
      list belongs to this `Task` object itself, followed by that of the first
      subtask, then of the second subtask, and so on. Each element list can be
      empty.
    - options: Similar to `premises` except that the annotation type is
      `Option`.
    """

    def __init__(self, name: str, subtasks: Optional[List] = None):
        """
        :param name: An arbitrary name. Useful for error logging.
        :param subtasks: Optional list of subtasks composing this task.
        """
        self.name = name
        self.is_completed = False
        self.results = []
        self.errlog = []
        # A task can be composed of a list of subtasks.
        self._subtasks = subtasks or []
        # Lazily-computed caches for the `premises`/`options` properties.
        # Both are lists of lists (self first, then one list per subtask).
        self._premises = None
        self._options = None

    def __str__(self):
        completed = "yes" if self.is_completed else "no"
        keys = ", ".join(e.key for e in self.results if isinstance(e, Datum))
        log = ("\n " + "\n ".join(self.errlog)) if self.errlog \
            else "(no errors)"
        lines = [
            "%s: %s" % (self.name, type(self).__name__),
            " Completed: %s" % completed,
            " Results's Keys: %s" % keys,
            " Log: %s" % log,
        ]
        return "\n".join(lines)

    def _gather(self, annotation_type, prop_name):
        """
        Collects (argument name, annotation) pairs of `self.execute`'s
        parameters whose annotations are instances of `annotation_type`,
        followed by the corresponding collections of all subtasks (obtained
        via the `prop_name` property on each subtask).
        """
        signature = inspect.signature(self.execute)
        own = [(name, param.annotation)
               for name, param in signature.parameters.items()
               if isinstance(param.annotation, annotation_type)]
        return [own] + [getattr(sub, prop_name) for sub in self._subtasks]

    @property
    def premises(self):
        if self._premises is None:
            self._premises = self._gather(Premise, "premises")
        return self._premises

    @property
    def options(self):
        if self._options is None:
            self._options = self._gather(Option, "options")
        return self._options

    def clear(self):
        """
        Cleans the state of this object for a new execution.
        """
        self.is_completed = False
        self.results = []
        self.errlog = []

    def execute(self, db: ArkDb):
        """
        Executes this task. This should only be called after all premises of
        this task are met. The premises of the subtasks are ignored until the
        subtask is executed. Subclasses should implement an `execute`, either
        as an instance method, or as an instance's public callable attribute.

        After execution, all results desired to be put into the database
        should be saved as the `results` attribute.

        The first argument of `execute` should always be for the database.

        :raise SubtaskExecutionError: If any subtask failed to execute; the
            subtasks' error messages are aggregated into `self.errlog`.
        """
        if not self._subtasks:
            return
        if execute(db, self._subtasks):
            return
        self.errlog = collect_logs(self._subtasks)
        raise SubtaskExecutionError("Subtask execution failed.")
class ParchTrajectoryForSolubilityFep(Task):
    """
    Task to parch the trajectory of the fully-interacting lambda state of a
    solubility FEP leg. The results are always reported under the
    "ResultLambda1" field.

    Results are all `Datum` objects:
    - key = "Keywords[i].ResultLambda1.ParchedCmsFname"
      val = Name of the parched CMS file
    - key = "Keywords[i].ResultLambda1.ParchedTrjFname"
      val = Name of the parched trajectory file
    """

    def __init__(self,
                 name,
                 cms_fname_pattern: str,
                 trj_fname_pattern: str,
                 out_bname_pattern: str,
                 num_solvent: int = 200):
        """
        The values of the arguments: `cms_fname_pattern`, `trj_fname_pattern`,
        and `out_bname_pattern`, are simple strings that specify f-string
        patterns to be evaluated at execution time to get the corresponding
        file names.
        Example, `"{jobname}_replica_{index}-out.cms"`, note that it's a
        simple string and uses two f-string variables `{jobname}` and
        `{index}`. The values of the f-string variables will be obtained on
        the fly when the task is executed. Currently, the following f-string
        variables are available for this task::

            {jobname} - The FEP job's name
            {index}   - The index number of the replica of the last lambda
                        window.
        """
        super().__init__(name)

        # `execute` closes over this method's arguments, so it is defined on
        # the fly and attached as a public callable attribute (see `Task`).
        # The `Premise` annotations declare the data that must be present in
        # the database; they are resolved by the module-level `execute`
        # function before this callable is invoked.
        def execute(
                _,
                jobname: Premise(
                    "Keywords[i].FEPSimulation.JobName"),  # noqa: F821
                dew_asl: Premise(
                    "Keywords[i].ResultLambda1.LigandASL"),  # noqa: F821
                replica: Premise("Keywords[i].Replica")  # noqa: F821
        ):
            from schrodinger.application.desmond.packages import parch

            # Index of the replica of the last lambda window.
            num_win = len(replica)
            index = num_win - 1
            # The *_pattern arguments are f-string templates evaluated here,
            # where `jobname` and `index` are in the local scope. The
            # patterns are internal, trusted strings (not user input).
            cms_fname = eval(f"f'{cms_fname_pattern}'")
            cms_fname = util.gz_fname_if_exists(cms_fname)
            cms_fname = util.verify_file_exists(cms_fname)
            trj_fname = util.verify_traj_exists(eval(f"f'{trj_fname_pattern}'"))
            out_bname = eval(f"f'{out_bname_pattern}'")
            # yapf: disable
            cmd = util.commandify([
                cms_fname, trj_fname, out_bname,
                ['-output-trajectory-format', 'auto'],
                ['-dew-asl', dew_asl],
                ['-n', num_solvent]])
            # yapf: enable
            out_cms_fname, out_trj_fname = parch.main(cmd)
            self.results = [
                Datum("Keywords[i].ResultLambda1.ParchedCmsFname",
                      out_cms_fname),
                Datum("Keywords[i].ResultLambda1.ParchedTrjFname",
                      out_trj_fname),
            ]

        self.execute = execute
class ParchTrajectoryForFepLambda(Task):
    """
    Task to parch the trajectory for the given FEP lambda state. The lambda
    state is represented by 0 and 1.

    Results are all `Datum` objects:
    - key = "Keywords[i].ResultLambda{result_lambda}.ParchedCmsFname" and
      "Keywords[i].ResultLambda{result_lambda}.ParchedTrjFname", where
      `{result_lambda}` is the value of the lambda state under which the
      results are reported.
    - val = Names of the parched CMS and trajectory files, respectively.

    We leave this class here (1) to explain how the framework basically works
    and (2) to demonstrate how to create a concrete `Task` subclass.

    - Introduction

      From the architectural point of view, one of the common and difficult
      issues in computation is perhaps data coupling: Current computation
      needs data produced by previous ones. It's difficult because the
      coupling is implicit and across multiple programming
      units/modules/files, which often results in bugs when a code change in
      one place implicitly breaks code somewhere else.

      Taking this class as an example, the task is trivial when explained at
      the conceptual level: Call the `trj_parch.py` script with properly set
      options to generate a "parched" trajectory. But when we get to the
      details of incorporating this task in a workflow, it becomes very
      complicated, mostly because of the data coupling issue (which is the
      devil here): From the view point of this task, we have to check the
      following data dependencies:

      1. The input files (the output CMS file and the trajectory file) exist.
      2. We identify the input files by file name patterns that depend on the
         current jobname which is supposed to be stored in a (.sid) data
         file. So we have to ensure the jobname exists in the database.
         (Alternatively, we can pass the jobname through a series of function
         calls, but we won't discuss the general issues of that approach)
      3. To call trj_parch.py, we must set the `-dew-asl` and `-fep-lambda`
         options correctly. The values for these options are either stored in
         the .sid data file or passed into this class via an argument of the
         `__init__` method.

      Furthermore, when any of these conditions are not met, informative
      error messages must be logged.

      All of these used to force the developer to write a LOT of boilerplate
      code to get/put data from the database, to check these conditions, and
      to log all errors, for even the most conceptually trivial task. More
      often than not, such boring (and repeated) code is either incomplete or
      not in place at all. And we take the risk of doing computations without
      verifying the data dependencies, until some code changes break one of
      the conditions.

    - Four types of data

      We must realize where the coupling comes into the architecture of our
      software. For this, it helps to categorize data into the following
      types in terms of the source of the data:

      1. Hard coded data:
         - This type of data is hard coded and rarely needs to be modified or
           customized. Example, `num_solvent=200`.
      2. Arguments:
         - Data passed into the function by the caller code. Example,
           `fep_lambda`.
      3. From the database:
         - Examples: jobname, ligand ASL, number of lambda windows.
      4. Assumptions:
         - Assumptions are data generated by previous stages in a workflow
           but are out of the control of the task of interest.
           For example, we have to assume the CMS and trajectory files
           following certain naming patterns exist in the file system. In
           theory, the fewer assumptions, the more robust the code. But in
           practice, it is very difficult (if not impossible) to totally
           avoid assumptions.

      Implicit data coupling happens for the types (3) and (4) data.

    - The task framework

      The basic idea of this framework is to make the types (3) and (4) data
      more explicitly and easily defined in our code, which will then make it
      possible to automatically check their availability and log errors.

      For the type (3) data, we provide `Premise` and `Option` classes for
      getting the data.

      For the type (4) data, we have to rely on a convention to verify the
      assumptions. But utility functions are provided to make that easier and
      idiomatic.

      In both cases, when the data are unavailable, informative error
      messages will be automatically logged.

      The goal of this framework is to relieve the developer from writing a
      lot of boilerplate code and shift their attention to writing reusable
      tasks.
    """

    def __init__(self,
                 name,
                 fep_lambda: int,
                 result_lambda: int,
                 cms_fname_pattern: str,
                 trj_fname_pattern: str,
                 out_bname_pattern: str,
                 num_solvent: int = 200):
        """
        The values of the arguments: `cms_fname_pattern`, `trj_fname_pattern`,
        and `out_bname_pattern`, are simple strings that specify f-string
        patterns to be evaluated at execution time to get the corresponding
        file names.
        Example, `"{jobname}_replica_{index}-out.cms"`, note that it's a
        simple string and uses two f-string variables `{jobname}` and
        `{index}`. The values of the f-string variables will be obtained on
        the fly when the task is executed. Currently, the following f-string
        variables are available for this task::

            {jobname}       - The FEP job's name
            {fep_lambda}    - Same value as that of the argument
                              `fep_lambda`. It's either 0 or 1.
            {result_lambda} - Same value as that of the argument
                              `result_lambda`. It's either 0 or 1.
            {index}         - The index number of the replica corresponding
                              to either the first lambda window or the last
                              one, depending on the value of the `fep_lambda`
                              argument.
        """
        super().__init__(name)

        # Because the `execute` depends on the arguments of the `__init__`
        # method, we define `execute` on the fly.
        # It's possible to define `execute` as an instance method. But then
        # we would need to save the `cms_fname_pattern`, etc. arguments,
        # which are not used elsewhere. It's less verbose to define `execute`
        # as a callable attribute.
        # yapf: disable
        def execute(_,
                    jobname: Premise("Keywords[i].FEPSimulation.JobName"),  # noqa: F821
                    dew_asl: Premise(f"Keywords[i].ResultLambda{result_lambda}.LigandASL"),  # noqa: F821,F722
                    replica: Premise("Keywords[i].Replica"),  # noqa: F821
                    ref_mae: Option("ReferenceStruct")  # noqa: F821
                    ):
            # yapf: enable
            """
            We define three `Premise`s and one `Option` for `execute`. Each
            of them refers to a datum keyed by the corresponding string in
            the database.
            The `Premise`s will be checked against the present database by
            the module-level `execute` function below. If any of these
            `Premise`s are not met, an error will be recorded, and this
            `execute` function will not be called.
            """
            from schrodinger.application.desmond.packages import parch

            num_win = len(replica)
            # 0 for the first lambda window, or the last window's index for
            # fep_lambda == 1.
            index = fep_lambda and (num_win - 1)
            # The *_pattern arguments are f-string templates evaluated here,
            # where `jobname`, `index`, `fep_lambda`, and `result_lambda` are
            # in the local scope. The patterns are internal, trusted strings
            # (not user input).
            cms_fname = eval(f"f'{cms_fname_pattern}'")
            cms_fname = util.gz_fname_if_exists(cms_fname)
            cms_fname = util.verify_file_exists(cms_fname)
            trj_fname = util.verify_traj_exists(eval(f"f'{trj_fname_pattern}'"))
            out_bname = eval(f"f'{out_bname_pattern}'")
            # yapf: disable
            cmd = util.commandify([
                cms_fname, trj_fname, out_bname,
                ['-output-trajectory-format', 'auto'],
                ['-dew-asl', dew_asl],
                ['-n', num_solvent],
                ['-fep-lambda', fep_lambda],
                ['-ref-mae', ref_mae]])
            # yapf: enable
            out_cms_fname, out_trj_fname = parch.main(cmd)
            result_field = f"Keywords[i].ResultLambda{result_lambda}"
            self.results = [
                Datum(f"{result_field}.ParchedCmsFname", out_cms_fname),
                Datum(f"{result_field}.ParchedTrjFname", out_trj_fname),
            ]

        self.execute = execute
class ParchTrajectoryForFep(Task):
    """
    Task to generate parched trajectories for both FEP lambda states. The
    lambda state is represented by 0 and 1.

    Results are all `Datum` objects:
    - key = "ResultLambda0.ParchedCmsFname"
      val = Name of the parched CMS file for lambda state 0: "lambda0-out.cms"
    - key = "ResultLambda1.ParchedCmsFname"
      val = Name of the parched CMS file for lambda state 1: "lambda1-out.cms"
    - key = "ResultLambda0.ParchedTrjFname"
      val = Name of the parched trajectory file for lambda state 0::
          "lambda0{ext}", where "{ext}" is the same extension as the input
          trajectory file name.
    - key = "ResultLambda1.ParchedTrjFname"
      val = Name of the parched trajectory file for lambda state 1::
          "lambda1{ext}", where "{ext}" is the same extension as the input
          trajectory file name.

    We leave this class here to demonstrate how to define a concrete `Task`
    subclass by composition.
    """

    def __init__(self, name, num_solvent=200):
        # The file name patterns are fixed by convention and not expected to
        # change. If different patterns are needed, create a new `Task`
        # subclass similar to this one.
        common_args = (
            "{jobname}_replica{index}-out.cms",  # cms_fname_pattern
            "{jobname}_replica{index}",          # trj_fname_pattern
            "lambda{fep_lambda}",                # out_bname_pattern
            num_solvent,
        )
        # One subtask per lambda state; results reported under the same state.
        subtasks = [
            ParchTrajectoryForFepLambda("%s_lambda%d" % (name, state), state,
                                        state, *common_args)
            for state in (0, 1)
        ]
        super().__init__(name, subtasks)
class ParchTrajectoryForAbsoluteFep(Task):
    """
    Task to generate the parched trajectory for the lambda state with the
    fully-interacting ligand.

    Results are all `Datum` objects:
    - key = "ResultLambda0.ParchedCmsFname"
      val = Name of the parched CMS file: "lambda0-out.cms"
    - key = "ResultLambda0.ParchedTrjFname"
      val = Name of the parched trajectory file::
          "lambda0{ext}", where "{ext}" is the same extension as the input
          trajectory file name.
    """

    def __init__(self, name, num_solvent=200):
        # Absolute binding calculations are set up such that the mutant
        # structure of replica 0 contains the fully interacting ligand.
        # Parch will remove the reference structure (dummy particle) but
        # keep the mutant structure (ligand).
        # Results are reported under fep_lambda=0, to make them consistent
        # with the rest of the analysis.
        # TODO: we may decide to keep the apo (fep_lambda=1) state of
        # the protein, in which case we need to handle it here.
        fep_lambda, report_lambda = 1, 0
        subtask = ParchTrajectoryForFepLambda(
            f"{name}_lambda0",
            fep_lambda,
            report_lambda,
            "{jobname}_replica0-out.cms",  # cms_fname_pattern
            "{jobname}_replica0",          # trj_fname_pattern
            "lambda0",                     # out_bname_pattern
            num_solvent)
        super().__init__(name, [subtask])
class TrajectoryForSolubilityFep(Task):
    """
    Task to generate the parched trajectory for the lambda state with the
    fully-interacting molecule.

    Results are all `Datum` objects:
    - key = "ResultLambda1.ParchedCmsFname"
      val = Name of the parched CMS file: "lambda1-out.cms"
    - key = "ResultLambda1.ParchedTrjFname"
      val = Name of the parched trajectory file::
          "lambda1{ext}", where "{ext}" is the same extension as the input
          trajectory file name.
    """

    def __init__(self, name, num_solvent=200):
        subtask = ParchTrajectoryForSolubilityFep(
            f"{name}_lambda1",
            "{jobname}_replica{index}-out.cms",  # cms_fname_pattern
            "{jobname}_replica{index}",          # trj_fname_pattern
            "lambda1",                           # out_bname_pattern
            num_solvent)
        super().__init__(name, [subtask])
def execute(arkdb: ArkDb, tasks: Iterable[Task]) -> bool:
    """
    Executes one or more tasks against the given database `arkdb`.

    This function is guaranteed to do the following:

    1. This function will examine each task's premises against the database.
    2. If the premises are NOT met, it skips the task; otherwise, it will
       proceed to check the task's options against the database. A missing
       option is logged but does NOT prevent execution.
    3. After getting the premises and options data, it will call the task's
       `execute` callable object. If the execution of the task is completed
       without errors, it will set the task's `is_completed` attribute to
       true and put the task's `Datum` results into the database.
    4. During the above steps, errors (if any) will be logged in the task's
       `errlog` list.
    5. After doing the above for all tasks, this function will return `True`
       if all tasks are completed without errors, or `False` otherwise.

    :param arkdb: Database to get the premise/option data from and to put
        the results into.
    :param tasks: Tasks to be executed. May be any iterable, including a
        one-shot iterator such as a generator.
    :return: Whether all tasks completed successfully.
    """
    # Materialize `tasks` exactly once: a one-shot iterator (e.g. a
    # generator) would otherwise be exhausted by the execution loop, making
    # the final `all(...)` check trivially return `True`.
    tasks = list(tasks)
    for ta in tasks:
        ta.clear()
        kwargs = {}
        # Check this task's own premises (the first element list) against
        # the database, collecting the fetched values for the call below.
        for arg_name, dat in ta.premises[0]:
            try:
                dat.get_from(arkdb)
            except ArkDbGetError as e:
                ta.errlog.append(f"Premise '{dat.key}' failed: {e}")
            kwargs[arg_name] = dat.val
        if ta.errlog:
            # At least one premise is unmet: skip this task.
            continue
        # Premises are met. Options are fetched the same way, but a failure
        # only adds to the log; the task is still executed.
        for arg_name, dat in ta.options[0]:
            try:
                dat.get_from(arkdb)
            except ArkDbGetError as e:
                ta.errlog.append(f"Option '{dat.key}' failed: {e}")
            kwargs[arg_name] = dat.val
        try:
            ta.execute(arkdb, **kwargs)
        except SubtaskExecutionError as e:
            # Subtask errors were already aggregated into `ta.errlog` by the
            # task itself; prepend the summary message.
            ta.errlog.insert(0, f"{e}")
        except Exception as e:
            ta.errlog.append("Task execution failed:\n%s\n%s" %
                             (e, traceback.format_exc()))
        else:
            for r in ta.results:
                if isinstance(r, Datum):
                    r.put_to(arkdb)
            ta.is_completed = True
    return all(ta.is_completed for ta in tasks)
def collect_logs(tasks: Iterable[Task]) -> List[str]:
    r"""
    Iterates over the given `Task` objects, and aggregates the logs of
    uncompleted tasks into a list to return.

    For each uncompleted task, one line with the task's name and class name
    is emitted, followed by the task's error messages, each on a separate
    line with extra indentation, e.g.::

        task0: Task
         message
         another message
        task1: ConcreteTaskForTesting
         message
         another arbitrary message

    The returned strings can be joined and printed out::

        print("\n".join(collect_logs(...)))

    Note the purpose of returning a list of strings instead of a single
    string is to make it slightly easier to further indent the text. For
    example, to indent the whole text by two spaces::

        print("  %s" % "\n  ".join(collect_logs(...)))
    """
    logs = []
    uncompleted = (task for task in tasks if not task.is_completed)
    for task in uncompleted:
        logs.append("%s: %s" % (task.name, type(task).__name__))
        logs.extend(" " + message for message in task.errlog)
    return logs
class Premise(Datum):
    """
    A premise is a datum that MUST be available in the database for a task
    (see `Task` above) to be successfully executed. Compare `Option`.
    """

    def __init__(self, key):
        # A premise is identified by its key only; narrow the `Datum`
        # constructor interface accordingly.
        super().__init__(key)
class Option(Datum):
    """
    An option is a datum that does NOT have to be available in the database
    for a task (see `Task` above) to be successfully executed. Compare
    `Premise`.
    """