# Source code for schrodinger.application.desmond.stage.workflow

"""
Various multisim concrete stage classes.

Copyright Schrodinger, LLC. All rights reserved.

"""
import copy
import glob
import os
from past.utils import old_div
from pathlib import Path

from schrodinger import structure
from schrodinger.application.desmond import cmj
from schrodinger.application.desmond import cms
from schrodinger.application.desmond import constants
from schrodinger.application.desmond import struc
from schrodinger.application.desmond import systype
from schrodinger.application.desmond import util
from schrodinger.application.desmond.picklejar import Picklable
from schrodinger.application.desmond.picklejar import PickleState
from schrodinger.utils import sea

from .jobs import DesmondJob

__all__ = ['Primer', 'Concluder', 'Task', 'Extern', 'Trim', 'Stop']


class Primer(cmj.StageBase):
    """
    First stage of a workflow.

    It performs no computation of its own: `start` wraps the workflow's input
    structure file in a single job that is already marked successful, so the
    following stage has a prejob to consume.
    """
    NAME = "primer"

    def __init__(self, *arg, **kwarg):
        """
        Forward every argument unchanged to `cmj.StageBase`.
        """
        cmj.StageBase.__init__(self, *arg, **kwarg)

    def describe(self):
        """
        Intentionally does nothing.
        """
        pass

    def start(self, fname_in):
        """
        Register one successful job whose output is the input structure file.

        :param fname_in: Path of the workflow's input structure file.
        """
        # The job's only output is the original input structure file.
        seed_job = cmj.Job(None, None, self, None, None, None, None)
        seed_job.output.add(fname_in)
        seed_job.status.set(cmj.JobStatus.SUCCESS)
        self.add_job(seed_job)
        self._NEXT_STAGE.push(None)
class Concluder(cmj.StageBase):
    """
    Last stage of a workflow.

    This stage does not create any jobs for the 'cmj.QUEUE'. On release it
    writes the structure output of the first prejob to the final output file
    (when the paths differ) and, for relative output names, registers the
    file with job control.
    """
    NAME = "concluder"

    def __init__(self, fname_out=None, *arg, **kwarg):
        """
        :param fname_out: Name of the final structure file. If None, the name
            is taken at release time from the `struct_output` parameter of the
            last non-'stop', unskipped stage. An empty string disables writing
            the final structure file.
        """
        super().__init__(*arg, **kwarg)
        self._fname_out = fname_out

    def describe(self):
        # Empty to prevent this stage from showing up in log file
        pass

    def release(self):
        """
        Write the final structure file and register it as job output.
        """
        if self._fname_out is None:
            # Finds the last non-'stop' unskipped stage.
            stage = self._PREV_STAGE
            while stage and (stage.NAME == "stop" or
                             stage.param.should_skip.val):
                stage = stage._PREV_STAGE
            if stage:
                self._fname_out = stage.param.struct_output.val
        self._print("debug", "Final structure file: %s" % self._fname_out)

        # `_fname_out` may still be None if no suitable preceding stage was
        # found; `os.path.join` would raise TypeError on it, so treat None
        # exactly like "" (no final structure file).
        if not self._fname_out:
            return

        fname_out = os.path.join(cmj.ENGINE.base_dir, self._fname_out).strip()
        output = self.get_prejobs()[0].output.struct_file()
        if output is None:
            return
        output = output.strip()
        if not os.path.isfile(output):
            return

        if fname_out != output:
            with structure.StructureReader(output) as reader:
                with structure.StructureWriter(fname_out) as writer:
                    writer.extend(reader)
        # When using an absolute path, the file is copied above. Otherwise,
        # use jobcontrol to transfer the file to the launch directory.
        if not Path(self._fname_out).is_absolute():
            cmj.ENGINE.JOBBE.addOutputFile(self._fname_out)
            cmj.ENGINE.JOBBE.setStructureOutputFile(self._fname_out)
class Task(cmj.StageBase):
    """
    Detect the type of the input system and create jobs for the requested
    task.

    The `task` parameter selects the system typer (e.g. "desmond:fep",
    "mcpro:auto", "watermap", "generic"). The `set_family` parameter maps
    stage-family names to parameter settings that are merged into the
    parameters of this and every following stage belonging to that family.
    """
    NAME = "task"

    # Counter for job group ids, shared by all Task stages. It is declared as
    # a `(value, Picklable)` pair but incremented as an int below, so
    # presumably `picklejar` unwraps this declaration into a picklable class
    # attribute -- TODO confirm against `picklejar`.
    gid = 0, Picklable
    PARAM = cmj._create_param_when_needed(""" DATA = { task = "generic" set_family = {} } VALIDATE = { task = {type = enum range = ["desmond:regular" "desmond:fep" "desmond:afep" "desmond:auto" "desmond:raw" "mcpro:auto" "mcpro:fep" "generic" "watermap"]} set_family = [ {_mapcheck = check_stage_family _skip = all} ] } """)

    def __init__(self, *arg, **kwarg):
        """
        Forward all arguments to `cmj.StageBase` and reset detection state.
        """
        cmj.StageBase.__init__(self, *arg, **kwarg)
        self._systype = None
        self._systrait = []
        # Can take these values: 0 = not done yet; 1 = well done;
        # 2 = done but got error.
        self._setfamily = 0

    def __getstate__(self, state=None):
        """
        Add the detected system type and traits to the pickled state.
        """
        state = state if (state) else PickleState()
        state._systype = self._systype
        state._systrait = self._systrait
        # Do not pickle `_setfamily`: the `set_family` method must be rerun
        # when the job is restarted from the checkpoint file.
        return cmj.StageBase.__getstate__(self, state)

    def _detect_systype(self):
        """
        Detect the system type of every prejob's structure and, if all are
        compatible with the task, record the system traits.

        Sets `pj.systype` on each prejob and `self._systrait` (to the traits of
        the last prejob).

        :return: True if every prejob's system is usable for the task.
        """
        self._log("Detecting system type...")
        is_fine = True
        for pj in self.get_prejobs():
            output = pj.output.struct_file()
            try:
                is_compatible, syst, syst_str = self._systype.check_system(
                    output)
                self._log(" It seems " + syst_str)
                pj.systype = syst
                if not is_compatible:
                    self._log(" It cannot be used for task \"%s\"" %
                              self.param.task.val)
                    is_fine = False
            except systype.DetectionError as e:
                self._print("quiet", str(e))
                is_fine = False
        if is_fine:
            self._log("Recognizing traits of system...")
            for pj in self.get_prejobs():
                self._systrait = self._systype.get_trait(
                    pj.output.struct_file())
                if self._systrait:
                    # Each trait followed by one space (same text as before).
                    s = " ".join(str(e) for e in self._systrait) + " "
                    self._log(" " + s)
                else:
                    self._log(" (none)")
        return is_fine

    def _set_permanency(self):
        """
        Sets `pj.permanent_restrain` and `pj.permanent_group`
        Note that `pj.systype` should be already set before call to this
        method.
        """
        typer = self._systype.typer
        for pj in self.get_prejobs():
            if (typer.has_rule("mature") and
                    typer.check_rule(pj.systype, "mature")):
                if self._PREV_STAGE.NAME == "primer":
                    model = cms.Cms(file=pj.output.struct_file())
                    # when empty set permanent_restrain to None
                    pj.permanent_restrain = model.get_restrain() or None
                    pj.permanent_group = model.get_atom_group()
                    if pj.permanent_group == []:
                        pj.permanent_group = None
            else:
                pj.permanent_restrain = None
                pj.permanent_group = None

    def _newjob(self, task, pj, gid=None):
        """
        Create a `DesmondJob` for `task` derived from prejob `pj`.

        :param task: Task name stored on the new job ("regular", "fep", ...).
        :param pj: The prejob. A `DesmondJob` is deep-copied; any other job is
            wrapped in a new `DesmondJob` with a deep copy of its output.
        :param gid: Group id; when None, a new one is taken from `Task.gid`.
        """
        if gid is None:
            Task.gid += 1
            gid = Task.gid
        if isinstance(pj, DesmondJob):
            job = copy.deepcopy(pj)
            job.task = task
            job.gid = gid
        else:
            job = DesmondJob(task, gid, pj.permanent_restrain,
                             pj.permanent_group, "task", pj, self, None, None)
            job.output = copy.deepcopy(pj.output)
        return job

    def _create_job(self):
        """
        Create this stage's jobs from the prejobs according to the system
        typer, then mark every job of this stage as successful.
        """
        self._set_permanency()
        typer = self._systype.typer
        if typer == systype.DesmondTyper:
            self._create_desmond_jobs(typer)
        elif typer == systype.WatermapTyper:
            self._create_watermap_jobs()
        elif typer == systype.McproTyper:
            self._create_mcpro_jobs()
        else:
            for pj in self.get_prejobs():
                pj.stage = cmj.weakref.proxy(self)
                self.add_job(pj)
        for e in self.jobs:
            e.status.set(cmj.JobStatus.SUCCESS)

    def _create_desmond_jobs(self, typer):
        """
        Create afep/fep/regular Desmond jobs. A raw FEP prejob is split into
        one job (with its own directory and input file) per mutant structure.
        """
        param = self.param
        # Strip the "desmond:" prefix, leaving e.g. "auto", "fep", "afep".
        task = param.task.val[8:]
        i = 1
        for pj in self.get_prejobs():
            if (task in ["auto", "afep", "fep"] and
                    typer.check_rule(pj.systype, "afep")):
                self.add_job(self._newjob("afep", pj))
            elif (task in ["auto", "fep"] and
                  pj.systype == systype.DesmondTyper.MATURE_FEP):
                self.add_job(self._newjob("fep", pj))
            elif (task in ["auto", "fep"] and
                  pj.systype == systype.DesmondTyper.RAW_FEP):
                env_ct, ref_ct, mut_ct = cms.decomp_rawfep_structure(
                    struc.read_all_ct(pj.output.struct_file()))
                # All mutants of one raw FEP structure share a group id.
                Task.gid += 1
                for st in mut_ct:
                    job = self._newjob("fep", pj, Task.gid)
                    tmp_job_prefix = cmj.ENGINE.jobname + "_fep" + str(i)
                    if job.prefix is None:
                        job.prefix = os.path.join(param.prefix.val,
                                                  tmp_job_prefix)
                    else:
                        job.prefix = tmp_job_prefix
                    job_dir = os.path.join(cmj.ENGINE.base_dir, job.prefix)
                    if not os.path.isdir(job_dir):
                        os.makedirs(job_dir)
                    util.chdir(job_dir)
                    out_fname = os.path.abspath(cmj.ENGINE.jobname + "_" +
                                                "fep" + str(i) + ".mae")
                    job.output = cmj.JobOutput()
                    job.output.add(out_fname)
                    util.write_n_ct(out_fname, env_ct + ref_ct + [st])
                    self.add_job(job)
                    i += 1
            else:
                self.add_job(self._newjob("regular", pj))

    def _create_watermap_jobs(self):
        """
        Create one regular job per prejob whose structure file is
        "<prejob structure file>_solute.mae".
        """
        # NOTE(review): the "_solute.mae" file is presumably produced by
        # `self._systype.prepare` in `set_family` -- confirm.
        for pj in self.get_prejobs():
            output = pj.output.struct_file()
            solute_fname = output + "_solute.mae"
            new_job = self._newjob("regular", pj)
            new_job.output.set_struct_file(solute_fname)
            self.add_job(new_job)

    def _create_mcpro_jobs(self):
        """
        Create one MCPRO FEP job per perturbation pair. Protein and ligand
        fragment structures are written to files in the base directory.
        """
        for pj in self.get_prejobs():
            env_ct, ref_ct, mut_ct = cms.decomp_rawfep_structure(
                struc.read_all_ct(pj.output.struct_file()))
            util.chdir(cmj.ENGINE.base_dir)
            prot_mae = os.path.abspath(cmj.ENGINE.jobname + "-in-prot.mae")
            pert_pair = set()
            frag_fname = {}
            util.write_n_ct(prot_mae, env_ct)
            for i, ct in enumerate(mut_ct, start=1):
                this_ = ct.property[constants.FEP_FRAGNAME]
                that_ = ct.property[constants.FEP_WRTFRAG].split("||")
                fname = os.path.abspath(cmj.ENGINE.jobname + "-in-lig" +
                                        str(i) + ".mae")
                struc.delete_structure_properties(ct, constants.FEP_WRTFRAG)
                util.write_n_ct(fname, [ct])
                frag_fname[this_] = fname
                for t in that_:
                    # Store each unordered pair once, smaller name first.
                    pert_pair.add(tuple(sorted((this_, t))))
            # Iterate in sorted order: iteration order of a set of str tuples
            # depends on hash randomization, which would otherwise make the
            # `_fepN` job numbering differ between runs.
            for i, pair in enumerate(sorted(pert_pair), start=1):
                job = cmj.Job("",
                              pj,
                              self,
                              None,
                              None,
                              prefix=cmj.ENGINE.jobname + "_fep" + str(i))
                job.output.add(prot_mae, tag="prot")
                job.output.add(frag_fname[pair[0]], tag="ligs")
                job.output.add(frag_fname[pair[1]], tag="lige")
                self.add_job(job)
            self._NEXT_STAGE._pertdb_type = ref_ct[0].property[
                constants.FEP_PERTDB]

    def set_family(self):
        """
        Prepare the prejob systems and apply the `set_family` settings to
        every stage from this one onward that belongs to each named family.

        Sets `self._setfamily` to 1 on success or 2 if any affected stage
        fails parameter validation (the errors are printed).
        """
        global STAGE_FAMILY
        typer = self._systype.typer
        if typer == systype.DesmondTyper:
            task = self.param.task.val[8:]
            for pj in self.get_prejobs():
                if task in ["auto", "afep", "fep"] and typer.check_rule(
                        pj.systype, "afep"):
                    self.param.set_family["desmond"] = sea.Map(
                        'coulomb_method = pme')
                    self.param.set_family.desmond.add_tag("setbyuser")
        for pj in self.get_prejobs():
            output = pj.output.struct_file()
            self._systype.prepare(self, output)
        if self._setfamily != 0:
            return

        if STAGE_FAMILY is None:
            STAGE_FAMILY = all_stage_family()
        # Pairs of (set of stage type names, family name).
        family = [(STAGE_FAMILY[fn], fn) for fn in self.param.set_family]
        # Apply broader families first and narrower ones last so that the
        # more specific family settings win. Sort by family size: a proper
        # subset is always smaller than its superset. (Sorting by the sets
        # themselves is wrong: set `<` is the subset relation, a partial
        # order, so e.g. "md" could be applied after "simulate".)
        family.sort(key=lambda pair: len(pair[0]), reverse=True)

        affected_stage = set()
        for family_stages, family_name in family:
            stage = self
            while stage:
                if stage.NAME in family_stages:
                    self._update_family_param(stage.param, family_name)
                    affected_stage.add(stage)
                # also check for sub-stages on concatenate stages
                if stage.NAME == 'concatenate':
                    for stage_name in family_stages:
                        if stage_name in stage.param:
                            self._update_family_param(stage.param,
                                                      family_name)
                            substage_params = stage.param[stage_name]
                            for substage_param in substage_params:
                                self._update_family_param(
                                    substage_param, family_name)
                            affected_stage.add(stage)
                stage = stage._NEXT_STAGE

        error = ""
        # Report in stage order rather than (arbitrary) set order.
        for stage in sorted(affected_stage, key=lambda s: s._INDEX):
            ev = stage.check_param()
            if ev._err != "":
                error += "Value error(s) for stage[%d]:\n%s\n" % (
                    stage._INDEX,
                    ev._err,
                )
            if ev.unchecked_map:
                error += "Unrecognized parameters for stage[%d]: %s\n\n" % (
                    stage._INDEX,
                    ev.unchecked_map,
                )
        if error == "":
            self._setfamily = 1
        else:
            self._setfamily = 2
            print(error)

    def _update_family_param(self, param, family_name):
        """
        Update a param with this stage's set_family param for a given family
        name. Values already tagged "setbyuser" in `param` are re-applied
        last so they keep precedence.
        """
        setbyuser_map = sea.sea_filter(param, "setbyuser")
        param.update([self.param.set_family[family_name], setbyuser_map],
                     tag="setbyuser")

    def crunch(self):
        """
        Detect the system type, apply stage-family settings and, if both
        succeed, create this stage's jobs.
        """
        self._systype = systype.SysType(self.param.task.val)
        if self._detect_systype():
            self.set_family()
            if self._setfamily != 2:
                self._create_job()
class Extern(cmj.StageBase):
    """
    Run user-supplied Python code embedded in the msj file.

    - `command_once`: written to a module file in the base directory and
      imported; its `main(stage)` function, if defined, is called once.
    - `command`: written to a module file and imported the first time it is
      needed; its `main(stage, job)` function, if defined, is called for each
      prejob.

    SECURITY NOTE: this executes arbitrary code taken from the msj file;
    only run trusted msj files.
    """
    NAME = "extern"
    PARAM = cmj._create_param_when_needed([
        """ DATA = { auxiliary_file = "" command = "" command_once = "" backend = ? } VALIDATE = { auxiliary_file = [ {type = str range = [0 10000000000]} {type = list size = 0 elem = {type = str range = [0 10000000000]} } ] command = {type = str range = [0 10000000000]} command_once = {type = str range = [0 10000000000]} backend = [{_skip = all} {type = none}] } """,
    ])

    def __init__(self, *arg, **kwarg):
        """
        Forward all arguments to `cmj.StageBase` (with `should_pack=True`).
        """
        cmj.StageBase.__init__(self, should_pack=True, *arg, **kwarg)
        self._is_command_once_called = False
        self._is_command_imported = False
        # Names of the attributes that `__getstate__` pickles; user code can
        # extend it via `serialize`.
        self._serialized_attribute = set([
            "_serialized_attribute",
            "_is_command_once_called",
        ])

    def __getstate__(self, state=None):
        """
        Copy every attribute named in `_serialized_attribute` into `state`.
        """
        state = state if (state) else PickleState()
        for e in self._serialized_attribute:
            state.__dict__[e] = self.__dict__[e]
        return cmj.StageBase.__getstate__(self, state)

    def serialize(self, attribute):
        """
        Mark attribute name(s) to be included in the pickled state.

        :param attribute: An attribute name or a list of names.
        """
        attribute = attribute if (isinstance(attribute,
                                             list)) else [attribute]
        self._serialized_attribute.update(attribute)

    def _import_user_code(self, code, kind):
        """
        Write `code` to "multisim_stage_<ID>_<kind>.py" in the base directory
        and import it as a module.
        """
        import importlib
        util.chdir(cmj.ENGINE.base_dir)
        fname = "multisim_stage_%d_%s.py" % (self._ID, kind)
        with open(fname, "w") as fh:
            fh.write(code)
        # The module file was just created; the import system's cached
        # directory listings may not see it yet without this.
        importlib.invalidate_caches()
        return __import__(fname[:-3])

    def crunch(self):
        """
        Pass every prejob through unchanged, running the user commands.
        """
        if cmj.ENGINE.base_dir not in cmj.sys.path:
            cmj.sys.path.append(cmj.ENGINE.base_dir)
        for pj in self.get_prejobs():
            self.add_job(pj)
            if ("" != self.param.command_once.val and
                    not self._is_command_once_called):
                command_once = self._import_user_code(
                    self.param.command_once.val, "command_once")
                try:
                    main = command_once.main
                except AttributeError:
                    pass
                else:
                    main(self)
                self._is_command_once_called = True
            if "" != self.param.command.val:
                if not self._is_command_imported:
                    self._command = self._import_user_code(
                        self.param.command.val, "command")
                    self._is_command_imported = True
                try:
                    main = self._command.main
                except AttributeError:
                    pass
                else:
                    main(self, pj)
class Trim(cmj.StageBase):
    """
    Delete intermediate files of earlier stages.

    - `save`: "all", or a list of stage indices (negative values are
      relative to this stage) whose compressed outputs are kept; the
      `compress` files of every other earlier stage are removed.
    - `erase` / `erase2`: lists of `[stage_index, glob_pattern]` pairs. The
      pattern is expanded with the macros of the referenced stage (0 means
      no stage context; negative means relative to this stage) and every
      match is removed.
    """
    NAME = "trim"
    PARAM = cmj._create_param_when_needed([
        """ DATA = { save = "all" erase = [] erase2 = [] } VALIDATE = { save = [ {type = enum range = ["all"]} {type = list size = 0 elem = {type = int}} ] erase = { type = list size = 0 elem = {type = list size = 2 elem = [{type = int} {type = str1}]} } erase2 = { type = list size = 0 elem = {type = list size = 2 elem = [{type = int} {type = str1}]} } } """,
    ])

    def __init__(self, *arg, **kwarg):
        """
        Forward all arguments to `cmj.StageBase` (with `should_pack=False`).
        """
        cmj.StageBase.__init__(self, should_pack=False, *arg, **kwarg)

    def crunch(self):
        """
        Remove the requested files, then pass the prejobs through.
        """
        util.chdir(cmj.ENGINE.base_dir)

        save = self.param.save.val
        if isinstance(save, list):
            kept_indices = [
                self._INDEX + idx if idx < 0 else idx for idx in save
            ]
            for earlier in cmj.ENGINE.stage[1:self._INDEX]:
                if earlier._INDEX in kept_indices:
                    continue
                compressed = earlier.param.compress.val
                # Only remove when the setting contains no macros, i.e. the
                # raw value equals the expanded one.
                if earlier.param.compress.raw_val == compressed:
                    # No effects if the file/dir does not exist.
                    util.remove_file(compressed)

        doomed = set()
        for entry in self.param.erase + self.param.erase2:
            stage_index = entry[0].val
            pattern = entry[1]
            if stage_index == 0:
                doomed.add(pattern.val)
                continue
            if stage_index > 0:
                target = cmj.ENGINE.stage[stage_index]
            else:
                target = self
                for _ in range(-stage_index):
                    target = target._PREV_STAGE
            # Defines the macros such as `$MAINJOBNAME` for `target` before
            # `pattern.val` is expanded.
            # N.B.: This mutates the global variable `sea.macro_dict`.
            for pj in self.get_prejobs():
                target._get_jobname_and_dir(pj)
                doomed.add(pattern.val)

        for pattern in doomed:
            matches = glob.glob(pattern)
            self._log(f"Tried deleting '{pattern}'")
            if not matches:
                self._log(" Files not found")
                continue
            self._log(f" Found and deleted files: {matches}")
            for path in matches:
                # No effects if the file/dir does not exist.
                util.remove_file(path)

        self.add_jobs(self.get_prejobs())
class Stop(cmj.StageBase):
    """
    End the workflow early.

    Unless skipped, this stage links itself directly to the 'concluder'
    stage in `prestage`, so every stage in between is bypassed.
    """
    NAME = "stop"

    def __init__(self, *arg, **kwarg):
        """
        Forward all arguments to `cmj.StageBase` (with `should_pack=False`).
        """
        cmj.StageBase.__init__(self, should_pack=False, *arg, **kwarg)

    def crunch(self):
        """
        Pass every prejob through unchanged.
        """
        self._print("debug", "In Stop.crunch")
        self.add_jobs(self.get_prejobs())
        self._print("debug", "Out Stop.crunch")

    def prestage(self):
        """
        Re-link the stage chain so this stage is followed by the concluder.
        """
        if not self.param.should_skip.val:
            concluder = self._NEXT_STAGE
            while concluder.NAME != "concluder":
                concluder = concluder._NEXT_STAGE
            self._NEXT_STAGE = concluder
            concluder._PREV_STAGE = self
        cmj.StageBase.prestage(self)

    def poststage(self):
        """
        Log that the remaining stages are skipped.
        """
        self._log("Stop the workflow now. Skipping subsequent stages...")
# Stage family stuff
# Maps each predefined family name to the set of stage NAMEs it contains.
# `all_stage_family` extends a copy of this with one single-stage family per
# registered stage plus a "generic" family.
PREDEFINED_STAGE_FAMILY = {
    "remd": {
        "replica_exchange",
        "lambda_hopping",
    },
    "md": {
        "simulate",
        "replica_exchange",
        "lambda_hopping",
        "concatenate",
        "fep_vrun",
    },
    "desmond": {
        "minimize",
        "simulate",
        "replica_exchange",
        "lambda_hopping",
        "fep_vrun",
        "concatenate",
    },
}
def all_stage_family():
    """
    Return a new dict mapping every stage-family name to its set of stage
    NAMEs.

    It starts from a deep copy of `PREDEFINED_STAGE_FAMILY` and adds an
    empty "generic" family. Then, for every name in
    `cmj.StageBase.stage_cls` other than "generic", it adds a single-member
    family named after that stage and adds the stage to "generic".
    """
    families = copy.deepcopy(PREDEFINED_STAGE_FAMILY)
    generic = families["generic"] = set()
    for name in cmj.StageBase.stage_cls:
        if name == "generic":
            continue
        families[name] = {name}
        generic.add(name)
    return families
# The following functions are used to check settings in *.msj files
STAGE_FAMILY = None


def _xchk_check_stage_family(map, valid, ev, prefix):
    """
    Record an error listing every key of `map` (a `set_family` map) that is
    not a known stage-family name.
    """
    global STAGE_FAMILY
    if STAGE_FAMILY is None:
        STAGE_FAMILY = all_stage_family()
    unknown = [f for f in map if f not in STAGE_FAMILY]
    if unknown:
        ev.record_error(prefix,
                        "Unrecognized stage family: " + ", ".join(unknown))


def _xchk_check_box_size(param, valid, ev, prefix):
    """
    Check that `param.size` matches `param.shape`: a list of 3 numbers for
    'orthorhombic', 6 for 'triclinic', and a single atom value otherwise.
    """
    shape = param.shape.val
    if shape == "orthorhombic":
        if not isinstance(param.size, list) or len(param.size) != 3:
            ev.record_error(
                prefix,
                "box.size must be a list of 3 real numbers when box shape is 'orthorhombic'"
            )
    elif shape == "triclinic":
        if not isinstance(param.size, list) or len(param.size) != 6:
            ev.record_error(
                prefix,
                "box.size must be a list of 6 real numbers when box shape is 'triclinic'"
            )
    elif not isinstance(param.size, sea.Atom):
        ev.record_error(
            prefix,
            "box.size must be a real number for '%s' box shape." % shape)


def _is_valid_position_ref(ref):
    """
    Return True if `ref` is an acceptable reference setting for a position
    restraint: 'retain', 'reset', or a flat list whose length is a multiple
    of 3 (cartesian coordinates).
    """
    if isinstance(ref, sea.Atom):
        return ref.val in ("retain", "reset")
    if isinstance(ref, sea.List):
        return len(ref) % 3 == 0
    return False


def _get_restraint_fc_length(map):
    """
    Return `len(map.fc)`, falling back to `len(map.force_constant)`.

    Raises AttributeError if neither is set.
    """
    try:
        return len(map.fc)
    except AttributeError:
        return len(map.force_constant)


def _xchk_check_restrain(map, valid, ev, prefix):
    """
    Validate one restraint specification `map`, recording errors on `ev`.

    Restraints with a `generator` are checked against the generator's
    requirements only. Otherwise the atom count, force constant, sigma and
    reference settings are checked.
    """
    if "generator" in map:
        generator = map.generator.val
        if generator in ["abfe_cross_link"]:
            try:
                n = _get_restraint_fc_length(map)
            except AttributeError:
                ev.record_error(
                    prefix, "Force constants for abfe_cross_link not found")
                return
            if n != 3:
                ev.record_error(
                    prefix,
                    "Force constants for abfe_cross_link must be a triplet of numbers"
                )
        elif generator in ["alpha_helix"]:
            try:
                n = _get_restraint_fc_length(map)
            except AttributeError:
                n = None
            if n is not None and n != 3:
                ev.record_error(
                    prefix,
                    "Force constants for alpha_helix, if set, must be a triplet of numbers"
                )
            try:
                if len(map.sigma) != 3:
                    ev.record_error(
                        prefix,
                        "Sigma for alpha_helix, if set, must be a triplet of numbers"
                    )
            except AttributeError:
                pass
            try:
                if len(map.ref) != 3:
                    ev.record_error(
                        prefix,
                        "Reference value for alpha_helix, if set, must be a triplet"
                    )
            except AttributeError:
                pass
        return

    # position harmonic: 1 atom, 3 k, 3 ref, sigma = None
    # position fbhw    : 1 atom, 1 k, 3 ref, 1 sigma
    # stretch fbhw     : 2 atom, 1 k, 1 ref, 1 sigma
    # angle fbhw       : 3 atom, 1 k, 1 ref, 1 sigma
    # torsion fbhw     : 4 atom, 1 k, 1 ref, 1 sigma
    # NOE              : 2 atom, 1 k, 2 ref, 2 sigma
    if "atom" not in map:
        ev.record_error(prefix, "No atom specified for restraining")
        return
    if "force_constant" not in map and "fc" not in map:
        ev.record_error(prefix, "No force constant specified for restraint")
        return
    if "ref" in map:
        ref = map.ref
    elif "reference_position" in map:
        ref = map.reference_position
    else:
        ref = None
    # `fc` is the short alias of `force_constant`; prefer it when present.
    fc = map.fc if ("fc" in map) else map.force_constant
    num_atom = 1 if isinstance(map.atom, sea.Atom) else len(map.atom)

    if num_atom == 1:
        if "sigma" in map:
            # position fbhw.
            if not isinstance(map.sigma, sea.Atom):
                ev.record_error(
                    prefix,
                    "Wrong sigma setting for position flat-bottom-harmonic restraint"
                )
                return
            # Use the resolved `fc`: reading `map.force_constant` directly
            # raised AttributeError when only `fc` was given.
            if not isinstance(fc, sea.Atom):
                ev.record_error(
                    prefix,
                    "Wrong force_constant setting for position flat-bottom-harmonic restraint"
                )
                return
            if ref and not _is_valid_position_ref(ref):
                ev.record_error(
                    prefix,
                    "Value of ref for flat-bottom-harmonic restraint should be either 'retain', "
                    "or 'reset', or a list of cartesian coordinates")
                return
        else:
            # position harmonic
            if ((isinstance(fc, sea.List) and len(fc) != 3) or
                    isinstance(fc, sea.Map)):
                ev.record_error(
                    prefix,
                    "Wrong force_constant setting for position harmonic restraint"
                )
                return
            if ref and not _is_valid_position_ref(ref):
                ev.record_error(
                    prefix,
                    "Value of ref for harmonic restraint should be either 'retain', "
                    "or 'reset', or a list of cartesian coordinates")
                return
    elif num_atom > 4:
        ev.record_error(prefix, "Wrong atom setting for internal restraints")


def _xchk_multisim_file_exists(map, valid, ev, prefix):
    """
    Record an error if the file named by `map.val` does not exist, unless
    the nearest enclosing map with a `should_skip` setting is skipped.
    """
    # Finds the 'should_skip' parameter. If no ancestor has one, the stage
    # is treated as not skipped (previously this dereferenced None).
    should_skip = False
    parent = map.parent()
    while parent:
        if "should_skip" in parent:
            should_skip = parent.should_skip.val
            break
        parent = parent.parent()
    if should_skip:
        return
    val = map.val
    if val != "" and not os.path.isfile(val):
        ev.record_error(
            prefix, "File not found: {0}, curdir = {1}, dir = {2}".format(
                val, os.path.abspath(os.curdir), os.listdir(os.curdir)))


def _xchk_check_pose_conf_restraint(map, valid, ev, prefix):
    """
    Record an error if both pose_conf_restraint and
    fep_enhance_sampling_dihedral are enabled.
    """
    parent = map.parent()
    if "fep_enhance_sampling_dihedral" in parent:
        if map.enable.val and parent.fep_enhance_sampling_dihedral.val:
            ev.record_error(
                prefix,
                "fep_enhance_sampling_dihedral and pose_conf_restraint cannot be enabled at the same time."
            )


sea.reg_xcheck("check_stage_family", _xchk_check_stage_family)
sea.reg_xcheck("check_restrain", _xchk_check_restrain)
sea.reg_xcheck("check_box_size", _xchk_check_box_size)
sea.reg_xcheck("multisim_file_exists", _xchk_multisim_file_exists)
sea.reg_xcheck("check_pose_conf_restraint", _xchk_check_pose_conf_restraint)

# The functions are used by multisim for "effect_if".
def systrait_is(stage, map_msj, arg):
    """
    "effect_if" predicate on system traits.

    :return: False if there is no `Task` stage at or before `stage`;
        otherwise True iff every trait in `arg` is among the traits detected
        by the nearest such `Task` stage.
    """
    # Walk back to the latest 'task' stage.
    task_stg = stage
    while task_stg is not None and not isinstance(task_stg, Task):
        task_stg = task_stg._PREV_STAGE
    if not task_stg:
        return False
    return all(trait in task_stg._systrait for trait in arg)
def systype_is(stage, map_msj, arg):
    """
    "effect_if"-style predicate on the system type.

    :return: False if there is no `Task` stage at or before `stage`;
        otherwise the result of comparing `arg` with the `rule` of the
        nearest such stage's system type.
    """
    # NOTE(review): unlike the other predicates, this one is not passed to
    # `cmj.reg_checking` at the bottom of this file -- confirm intent.
    task_stg = stage
    while task_stg is not None and not isinstance(task_stg, Task):
        task_stg = task_stg._PREV_STAGE
    if not task_stg:
        return False
    return arg == task_stg._systype.rule
def has_file(stage, map_msj, arg):
    """
    "effect_if" predicate: True if every path in `arg` is an existing
    regular file, resolving relative paths against `cmj.ENGINE.base_dir`.
    An empty `arg` yields True.

    The working directory is switched temporarily and is restored even if a
    check raises.
    """
    cwd = os.getcwd()
    util.chdir(cmj.ENGINE.base_dir)
    try:
        return all(os.path.isfile(fname) for fname in arg)
    finally:
        util.chdir(cwd)
def has_dir(stage, map_msj, arg):
    """
    "effect_if" predicate: True if every path in `arg` is an existing
    directory, resolving relative paths against `cmj.ENGINE.base_dir`.
    An empty `arg` yields True.

    The working directory is switched temporarily and is restored even if a
    check raises.
    """
    cwd = os.getcwd()
    util.chdir(cmj.ENGINE.base_dir)
    try:
        return all(os.path.isdir(fname) for fname in arg)
    finally:
        util.chdir(cwd)
def is_debugging(stage, map_msj, arg):
    """
    "effect_if" predicate: True when the general log level is "debug".
    The arguments are ignored.
    """
    level = cmj.GENERAL_LOGLEVEL
    return level == "debug"
# Register the "effect_if" predicates with multisim, in this order.
# NOTE(review): `systype_is` is defined in this module but not registered
# here -- confirm whether that is intentional.
for _name, _predicate in (
    ("systrait_is", systrait_is),
    ("has_file", has_file),
    ("has_dir", has_dir),
    ("is_debugging", is_debugging),
):
    cmj.reg_checking(_name, _predicate)
del _name, _predicate

if __name__ == "__main__":
    # Minimal smoke check: build a bare stage and a DesmondJob attached to it.
    stage = cmj.StageBase()
    job0 = DesmondJob("", None, None, None, None, None, stage, None, None)