Source code for schrodinger.application.desmond.cmj

"""
This module provides fundamental facilities for writing a multisim driver
script, for writing multisim concrete stage classes, and for dealing with
protocol files.

Copyright Schrodinger, LLC. All rights reserved.

"""

import copy
import glob
import os
import pickle
import shutil
import signal
import subprocess
import sys
import tarfile
import threading
import time
import weakref
from io import BytesIO
from typing import BinaryIO
from typing import Iterable
from typing import List
from typing import Optional
from typing import Union
from pathlib import Path

import schrodinger.application.desmond.bld_ver as bld
import schrodinger.application.desmond.cmdline as cmdline
import schrodinger.application.desmond.envir as envir
import schrodinger.application.desmond.picklejar as picklejar
import schrodinger.application.desmond.util as util
import schrodinger.infra.mm as mm
import schrodinger.job.jobcontrol as jobcontrol
import schrodinger.utils.sea as sea
from schrodinger.application.desmond import constants
from schrodinger.application.desmond import queue
from schrodinger.utils import fileutils

from .picklejar import Picklable
from .picklejar import PicklableMetaClass
from .picklejar import PickleJar

# Contributors: Yujie Wu

# Info
VERSION = "4.0.0"
BUILD = bld.desmond_build_version()

# Machinery
QUEUE = None
ENGINE = None

# Log
LOGLEVEL = [
    "silent",
    "quiet",
    "verbose",
    "debug",
]
GENERAL_LOGLEVEL = "quiet"

# Suffixes
PACKAGE_SUFFIX = ".tgz"
CHECKPOINT_SUFFIX = "-multisim_checkpoint"

# Filenames
CHECKPOINT_FNAME = "$MAINJOBNAME" + CHECKPOINT_SUFFIX

_PRODUCTION_SIMULATION_STAGES = ["lambda_hopping", "replica_exchange"]


def _print(loglevel, msg):
    if LOGLEVEL.index(loglevel) <= LOGLEVEL.index(GENERAL_LOGLEVEL):
        if loglevel == "debug":
            print("MSJDEBUG: %s" % msg)
        else:
            print(msg)

        sys.stdout.flush()


def _time_str_to_time(time_str, scale=1.0):
    h, m, s = [e[:-1] for e in time_str.split()]
    return scale * (float(h) * 3600 + float(m) * 60 + float(s))


def _time_to_time_str(inp_time):
    h, r = divmod(int(inp_time), 3600)
    m, s = divmod(r, 60)
    return "%sh %s' %s\"" % (h, m, s)

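# For example, _time_to_time_str(3661) yields "1h 1' 1\"", and
# _time_str_to_time("1h 1' 1\"") yields 3661.0; the two helpers are inverses
# for whole-second values.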


class JobStatus(object):
    # Good status
    WAITING = 101
    RUNNING = 102
    SUCCESS = 103

    # Bad status and non-retriable
    BACKEND_ERROR = 201
    PERMANENT_LICENSE_FAILURE = 202
    NON_RETRIABLE_FAILURE = 299

    # Bad status and retriable
    TEMPORARY_LICENSE_FAILURE = 301
    KILLED = 302
    FIZZLED = 303
    LAUNCH_FAILURE = 304
    FILE_NOT_FOUND = 305
    FILE_CORRUPT = 306
    STRANDED = 307
    CHECKPOINT_REQUESTED = 308
    CHECKPOINT_WITH_RESTART_REQUESTED = 309
    RETRIABLE_FAILURE = 399

    STRING = {
        WAITING: "is waiting for launching",
        RUNNING: "is running",
        SUCCESS: "was successfully finished",
        PERMANENT_LICENSE_FAILURE: ("could not run due to permanent license "
                                    "failure"),
        TEMPORARY_LICENSE_FAILURE: "died due to temporary license failure",
        KILLED: "was killed",
        FIZZLED: "fizzled",
        STRANDED: "was stranded",
        LAUNCH_FAILURE: "failed to launch",
        FILE_NOT_FOUND: ("was finished, but registered output files were not "
                         "found"),
        FILE_CORRUPT: ("was finished, but an essential output file was found "
                       "corrupt"),
        BACKEND_ERROR: "died due to backend error",
        RETRIABLE_FAILURE: "died on unknown retriable failure",
        NON_RETRIABLE_FAILURE: "died on unknown non-retriable failure",
        CHECKPOINT_REQUESTED: "user requested job be checkpointed",
        CHECKPOINT_WITH_RESTART_REQUESTED:
            "user requested job be checkpointed and restarted",
    }

    def __init__(self, code=WAITING):
        self._code = code
        self._error = None

    def __str__(self):
        s = ""
        try:
            s += JobStatus.STRING[self._code]
        except KeyError:
            if self._error is None:
                s += "unknown error"
        if self._error is not None:
            s += "\n" + self._error
        return s

    def __eq__(self, other):
        if isinstance(other, JobStatus):
            return self._code == other._code
        else:
            try:
                return self._code == int(other)
            except ValueError:
                raise NotImplementedError

    def __ne__(self, other):
        return not self.__eq__(other)

    def set(self, code, error=None):
        if isinstance(code, JobStatus):
            self._code = code
        else:
            try:
                self._code = int(code)
            except ValueError:
                raise NotImplementedError
        self._error = error

    def is_good(self):
        return self._code < 200

    def is_retriable(self):
        return self._code > 300

    def should_restart_from_checkpoint(self):
        return self._code == self.CHECKPOINT_WITH_RESTART_REQUESTED


class JobOutput(object):

    def __init__(self):
        # Key: file name. Value: None or a callable that checks the file
        self._file = {}
        self._type = {}  # Key: file name. Value: "file" | "dir"
        self._tag = {}  # Key: tag. Value: file name
        self._struct = None

    # Note on pickling: Values in `self._file' will be set to None when
    # `self' is pickled.

    def __len__(self):
        """
        Returns the number of registered output files.
        """
        return len(self._file)

    def __iter__(self):
        """
        Iterates through the registered output files. Note that the order of
        the files here is not necessarily the order of file registration.
        """
        for f in self._file:
            yield f

    def __list__(self):
        return list(self._file)

    def __deepcopy__(self, memo={}):  # noqa: M511
        newobj = JobOutput()
        memo[id(self)] = newobj
        newobj._file = copy.deepcopy(self._file)
        newobj._type = copy.deepcopy(self._type)
        newobj._tag = copy.deepcopy(self._tag)
        return newobj

    def __getstate__(self):
        tmp_dict = copy.copy(self.__dict__)
        _file = tmp_dict["_file"]
        for k in _file:
            _file[k] = None
        return tmp_dict

    def update_basedir(self, old_basedir, new_basedir):
        old_basedir += os.sep
        new_basedir += os.sep
        new_file = {}
        for k in self._file:
            v = self._file[k]
            if k.startswith(old_basedir):
                k = k.replace(old_basedir, new_basedir)
            new_file[k] = v
        self._file = new_file
        new_type = {}
        for k in self._type:
            v = self._type[k]
            if k.startswith(old_basedir):
                k = k.replace(old_basedir, new_basedir)
            new_type[k] = v
        self._type = new_type
        for k in self._tag:
            v = self._tag[k]
            if v.startswith(old_basedir):
                v = v.replace(old_basedir, new_basedir)
                self._tag[k] = v
        try:
            if self._struct and self._struct.startswith(old_basedir):
                self._struct = self._struct.replace(old_basedir, new_basedir)
        except AttributeError:
            pass
        try:
            new_cms = []
            for e in self.cms:
                new_cms.append(e.replace(old_basedir, new_basedir))
            self.cms = new_cms
        except AttributeError:
            pass

    def add(self, filename, checker=None, tag=None, type="file"):
        """
        :param type: either "file" or "dir".
        """
        if filename:
            if type not in ("file", "dir"):
                raise ValueError(
                    'Valid values for \'type\' are "file" and "dir". '
                    f'But "{type}" is given')
            self._file[filename] = checker
            self._type[filename] = type
            if tag is not None:
                if tag in self._tag:
                    old_filename = self._tag[tag]
                    del self._file[old_filename]
                    del self._type[old_filename]
                self._tag[tag] = filename

    def remove(self, filename):
        """
        Unregister the given output file.
        """
        try:
            del self._file[filename]
        except KeyError:
            pass
        try:
            del self._type[filename]
        except KeyError:
            pass
        for key, value in self._tag.items():
            if value == filename:
                del self._tag[key]
                break

    def get(self, tag):
        return self._tag.get(tag)

    def check(self, status):
        for fname in self._file:
            _print("debug", "checking output file: %s" % fname)
            if self._type[fname] == "file":
                if os.path.isfile(fname):
                    checker = self._file[fname]
                    if checker:
                        err_msg = checker(fname)
                        if err_msg:
                            status.set(JobStatus.FILE_CORRUPT, err_msg)
                            return
                else:
                    _print("debug", "Output file: %s not found" % fname)
                    try:
                        _print(
                            "debug", "Files in current directory: %s" %
                            str(os.listdir(os.path.dirname(fname))))
                    except OSError:
                        _print(
                            "debug",
                            "Directory not found: %s" % os.path.dirname(fname))
                    status.set(JobStatus.FILE_NOT_FOUND)
                    return
            elif self._type[fname] == "dir":
                if not os.path.isdir(fname):
                    _print("debug", "Output directory: %s not found" % fname)
                    try:
                        _print(
                            "debug", "Files in parent directory: %s" %
                            str(os.listdir(os.path.dirname(fname))))
                    except OSError:
                        _print(
                            "debug",
                            "Directory not found: %s" % os.path.dirname(fname))
                    status.set(JobStatus.FILE_NOT_FOUND)
                    return
        status.set(JobStatus.SUCCESS)

    def set_struct_file(self, fname):
        self._struct = fname
        if fname not in self._file:
            self.add(fname)

    def struct_file(self):
        if not self._struct:
            for fname in self:
                if fname.endswith((".mae", ".cms", ".maegz", ".cmsgz",
                                   ".mae.gz", ".cms.gz")):
                    return fname
        else:
            return self._struct
        return None

    def log_file(self):
        for fname in self:
            if fname.endswith(".log"):
                return fname
        return None


class JobInput(JobOutput):

    def __deepcopy__(self, memo={}):  # noqa: M511
        newobj = JobInput()
        memo[id(self)] = newobj
        newobj._file = copy.deepcopy(self._file)
        newobj._type = copy.deepcopy(self._type)
        newobj._tag = copy.deepcopy(self._tag)
        return newobj

    def cfg_file(self):
        for fname in self:
            if fname.endswith(".cfg"):
                return fname
        return None

    def incfg_file(self):
        for fname in self:
            if fname.endswith("in.cfg"):
                return fname
        return None

    def outcfg_file(self):
        for fname in self:
            if fname.endswith("out.cfg"):
                return fname
        return None


class JobErrorHandler:

    @staticmethod
    def default(job):
        """
        If the job status is bad, attempt to print the log file and nvidia-smi
        output.
        """
        if not job.status.is_good():
            job._print(
                "quiet", "jlaunch_dir: %s\n" % job.dir +
                "jlaunch_cmd: %s" % subprocess.list2cmdline(job.jlaunch_cmd))
            log_fname = job.output.log_file()
            if log_fname and os.path.exists(log_fname):
                job._print("quiet", "Log file : %s" % log_fname)
                with open(log_fname, "r") as f:
                    log_content = f.readlines()
                job._print("quiet",
                           "Log file content:\n%s" % ">".join(log_content))
                job._print("quiet", "(end of log file)\n")
            else:
                job._print("quiet", "No log file registered for this job\n")

            # call nvidia-smi and print output to log file
            if job.USE_GPU:
                try:
                    output = subprocess.check_output("nvidia-smi",
                                                     universal_newlines=True)
                    job._print("quiet", "nvidia-smi output:\n%s" % output)
                except (FileNotFoundError, subprocess.CalledProcessError):
                    job._print("quiet", "No nvidia-smi output available\n")

    @staticmethod
    def restart_for_backend_error(job):
        """
        Run the default handler and if the status is killed or backend error,
        mark the failure as retriable.
        """
        if not job.status.is_good():
            JobErrorHandler.default(job)
            if job.status in [JobStatus.BACKEND_ERROR, JobStatus.KILLED]:
                job.status.set(JobStatus.RETRIABLE_FAILURE)


def exit_code_is_defined(job):
    """
    Return True if job has an exit code. Failed jobs may not have exit codes
    if they are killed by the queueing system or otherwise untrackable.
    """
    try:
        int(job.ExitCode)
    except ValueError:
        return False
    return True


class Job(object):
    # most jobs do not use gpu
    USE_GPU = False

    class Time(object):

        def __init__(self, launch, start, end, num_cpu, cpu_time, duration):
            self.launch = launch
            self.start = start
            self.end = end
            self.num_cpu = num_cpu
            self.cpu_time = cpu_time
            self.duration = duration

    @staticmethod
    def _get_time_helper(jobtime):
        try:
            t = time.mktime(
                time.strptime(jobtime, jobcontrol.timestamp_format))
            s = time.ctime(t)
        except AttributeError:
            t = None
            s = "(unknown)"
        return t, s

    @staticmethod
    def get_time(jctrl, num_cpu):
        launch_time, str_launch_time = Job._get_time_helper(jctrl.LaunchTime)
        if jctrl.StartTime:
            start_time, str_start_time = Job._get_time_helper(jctrl.StartTime)
        else:
            return Job.Time(str_launch_time, "(not started)", "N/A", num_cpu,
                            "N/A", "N/A")
        if jctrl.StopTime:
            stop_time, str_stop_time = Job._get_time_helper(jctrl.StopTime)
        else:
            return Job.Time(str_launch_time, start_time, "(stranded)",
                            num_cpu, "N/A", "N/A")
        if start_time is not None and num_cpu != "(unknown)":
            cpu_time = util.time_duration(start_time, stop_time, num_cpu)
            duration = util.time_duration(start_time, stop_time)
        else:
            cpu_time = "(unknown)"
            duration = "(unknown)"
        return Job.Time(str_launch_time, str_start_time, str_stop_time,
                        num_cpu, cpu_time, duration)

    def get_proc_time(self):
        proc_time = Job.get_time(self.jctrl, self.num_cpu).cpu_time
        return _time_str_to_time(proc_time) if proc_time != "(unknown)" else 0.0

    def __init__(self,
                 jobname,
                 parent,
                 stage,
                 jlaunch_cmd,
                 dir,
                 host_list=None,
                 prefix=None,
                 what=None,
                 err_handler=JobErrorHandler.default,
                 is_output=True):
        self.jobname = jobname
        self.tag = None
        # Job object from which this `Job' object was derived.
        self.parent = parent
        # Other Job objects from which this `Job' object was derived.
        self.other_parent = None
        # Job control object, will be set once the job is launched.
        self.jctrl = None
        self.jlaunch_cmd = jlaunch_cmd  # Job launch command
        # List of hosts where this job can be running
        self.host_list = host_list
        # Actual host where this job is running
        self.host = jobcontrol.Host("localhost")
        # By default, subjobs do not need a host other than localhost.
        self.need_host = False
        self.num_cpu = 1
        self.use_hostcpu = False
        # Launch directory, also where the job's outputs will be copied back
        self.dir = dir
        self.prefix = prefix  # Prefix directory of the launch directory
        self.what = what  # A string that stores more specific job description
        self.output = JobOutput()  # Output file names
        self.input = JobInput()  # Input file names
        self.status = JobStatus()  # Job status
        # `None' or a callable object that will be called to handle job errors.
        self.err_handler = err_handler
        self._jctrl_hist = []
        self._has_run = False
        if self.parent and self.prefix is None:
            self.prefix = self.parent.prefix
        if isinstance(stage, weakref.ProxyType):
            self.stage = stage
        else:
            self.stage = weakref.proxy(stage)
        # Note on pickling: `self.err_handler' will not be pickled.
        # Whether the job was run with the current instance of ENGINE
        self.old = False
        # is_output is used to signal that this is a stage's output. Some
        # stages which implement `hook_captured_successful_job` should set
        # `is_output=False` on the intermediate jobs and then set
        # `is_output=True` on any final jobs it creates
        self.is_output = is_output
@property def is_for_jc(self) -> bool: """ Whether or not this job should be submitted to job control """ return self.is_launchable and isinstance(self.jlaunch_cmd, list) @property def is_launchable(self) -> bool: return bool(self.jlaunch_cmd) @property def failed(self) -> bool: return not self.status.is_good() @property def is_retriable(self) -> bool: # Retriable means it can be retried by the -RETRIES mechanism which is # different than just restarting a job when using -RESTART return self.status.is_retriable() @property def is_incomplete(self) -> bool: return self.failed or self.status in [ JobStatus.WAITING, JobStatus.RUNNING ] @property def is_restartable(self): return self.is_incomplete and self.is_launchable def __deepcopy__(self, memo={}): # noqa: M511 newobj = object.__new__(self.__class__) memo[id(self)] = newobj for k, v in self.__dict__.items(): if k in ["stage", "jctrl", "parent"]: value = self.__dict__[k] elif k == "other_parent": value = copy.copy(self.other_parent) elif k == "_jctrl_hist": value = [] else: value = copy.deepcopy(v, memo) setattr(newobj, k, value) return newobj def __getstate__(self, state=None): state = state if (state) else copy.copy(self.__dict__) if "err_handler" in state: del state["err_handler"] if "jctrl" in state: state["jctrl"] = str(self.jctrl) if "_jctrl_hist" in state: state["_jctrl_hist"] = ["removed_in_serialization"] if "jlaunch_cmd" in state: if callable(state["jlaunch_cmd"]): state["jlaunch_cmd"] = "removed_in_serialization" if "stage" in state: state["stage"] = (self.stage if (isinstance(self.stage, int)) else self.stage._INDEX) return state def __setstate__(self, state): self.__dict__.update(state) if "stage" in state and ENGINE: self.stage = weakref.proxy(ENGINE.stage[self.stage]) def __repr__(self): """ Returns the jobname string in the format: <jobname>. """ r = f"<{self.jobname}" if self.jctrl: r += f"({self.jctrl})" r += f" status: {self.status}" # Temporarily for testing if isinstance(self.stage, StageBase): r += f" stage: {self.stage.NAME}" r += f" is_output: {self.is_output}" r += f" old: {self.old} is_for_jc: {self.is_for_jc}" r += ">" return r def _print(self, loglevel, msg): """ The internal print function of this job. Printing is at the same 'loglevel' as self.stage. """ self.stage._print(loglevel, msg) def _log(self, msg): """ The internal log function of this job. """ self.stage._log(msg) def _host_str(self): """ Returns a string representing the hosts. """ if self.jlaunch_cmd: if '-HOST' in self.jlaunch_cmd: return self.jlaunch_cmd[self.jlaunch_cmd.index('-HOST') + 1] host_str = self.host.name if self.use_hostcpu and -1 == host_str.find(":"): host_str += ":%d" % self.num_cpu return host_str
[docs] def describe(self): if self.status != JobStatus.LAUNCH_FAILURE: self._print("quiet", " Launch time: %s" % self.jctrl.LaunchTime) self._print("quiet", " Host : %s" % self._host_str()) self._print( "quiet", " Jobname : %s\n" % self.jobname + " Stage : %d (%s)" % (self.stage._INDEX, self.stage.NAME)) self._print( "verbose", " Prefix : %s\n" % self.prefix + " Jlaunch_cmd: %s\n" % subprocess.list2cmdline(self.jlaunch_cmd) + " Outputs : %s" % str(list(self.output))) if self.what: self._print("quiet", " Description: %s" % self.what)

    def process_completed_job(self,
                              jctrl: jobcontrol.Job,
                              checkpoint_requested=False,
                              restart_requested=False):
        """
        Check for valid output and set status of job, assuming job is already
        complete.

        :param checkpoint_requested: Set to True if the job should checkpoint.
            Defaults to False.
        :param restart_requested: Set to True if the job should checkpoint and
            restart. Defaults to False.
        """
        self.jctrl = jctrl
        # Make sure the job data has been downloaded and flushed to disk
        self.jctrl.download()
        # Not available on windows
        if hasattr(os, 'sync'):
            os.sync()
        self._print(
            "debug",
            "Job seems finished. Checking its exit-status and exit-code...")
        self._print("debug", "Job exit-status = '%s'" % self.jctrl.ExitStatus)
        if self.jctrl.ExitStatus == "killed":
            self._print("debug", "Job exit-code = N/A")
            self.status.set(JobStatus.KILLED)
        elif self.jctrl.ExitStatus == "fizzled":
            self._print("debug", "Job exit-code = N/A")
            self.status.set(JobStatus.FIZZLED)
        else:
            exit_code = self.jctrl.ExitCode
            if not exit_code_is_defined(self.jctrl):
                # If the exit code is not set, the backend must have died
                # without collecting the exit code. This could happen if a job
                # is qdeled, or the backend gets killed by OOM, or the job
                # monitoring process is killed for any reason.
                # Set status to a retriable status.
                self.status.set(JobStatus.KILLED)
            elif exit_code == 0:
                if checkpoint_requested:
                    self.status.set(JobStatus.CHECKPOINT_REQUESTED)
                elif restart_requested:
                    self.status.set(
                        JobStatus.CHECKPOINT_WITH_RESTART_REQUESTED)
                else:
                    self.output.check(self.status)
            elif exit_code == 17:
                # The mmlic3 library will return the following error codes
                # upon checkout:
                #     0 : success
                #     15: temporary, retryable failure; perhaps the server
                #         couldn't be contacted
                #     16: all licenses are in use. SGE is capable of requeuing
                #         the job.
                #     17: fatal, unrecoverable license error.
                self.status.set(JobStatus.PERMANENT_LICENSE_FAILURE)
            elif exit_code in {15, 16}:
                self.status.set(JobStatus.TEMPORARY_LICENSE_FAILURE)
            else:
                self.status.set(JobStatus.BACKEND_ERROR)

    def requeue(self, jctrl: jobcontrol.Job):
        # Make sure the job data has been downloaded and flushed to disk
        jctrl.download()
        # Not available on windows
        if hasattr(os, 'sync'):
            os.sync()

        # Delete stale checkpoint files that are not needed for restarting
        def _filter_tgz(input_fnames: List[str]):
            return set(filter(lambda x: x.endswith('-out.tgz'), input_fnames))

        stale_input_tgz_fnames = _filter_tgz(jctrl.InputFiles) - _filter_tgz(
            jctrl.OutputFiles)
        for fname in stale_input_tgz_fnames:
            util.remove_file(fname)
        self._print("quiet", f"Restart checkpointed job: {self.jlaunch_cmd}")
        self._print("quiet",
                    f"Deleted stale input files: {stale_input_tgz_fnames}")
        self.stage.restart_subjobs([self])
        self.status.set(JobStatus.WAITING)
[docs] def finish(self): if self.status != JobStatus.LAUNCH_FAILURE: jobtime = Job.get_time(self.jctrl, self.num_cpu) self._print("quiet", "\n%s %s." % (str(self.jctrl), str(self.status))) self._print( "quiet", " Host : %s\n" % self._host_str() + " Launch time: %s\n" % jobtime.launch + " Start time : %s\n" % jobtime.start + " End time : %s\n" % jobtime.end + " Duration : %s\n" % jobtime.duration + " CPUs : %s\n" % self.num_cpu + " CPU time : %s\n" % jobtime.cpu_time + " Exit code : %s\n" % self.jctrl.ExitCode + " Jobname : %s\n" % self.jobname + " Stage : %d (%s)" % (self.stage._INDEX, self.stage.NAME), ) if self.err_handler: self.err_handler(self) if self.status.is_retriable(): self._print("quiet", " Retries : 0 - Job has failed too many times.") self.stage.finalize_job(self) self.stage.finalize_stage()
class _create_param_when_needed(object): def __init__(self, param): self._param = param def __get__(self, obj, cls): if cls == StageBase: a = sea.Map(self._param) a.add_tag("generic") else: a = None for c in cls.__bases__[::-1]: # left-most base takes precedence if issubclass(c, StageBase): if a is None: a = copy.deepcopy(c.PARAM) else: a.update(copy.deepcopy(c.PARAM)) a.update(self._param, tag="stagespec") setattr(cls, "PARAM", a) return a class _StageBaseMeta(PicklableMetaClass): def __init__(cls, name, bases, dict): PicklableMetaClass.__init__(cls, name, bases, dict) cls.stage_cls[cls.NAME] = cls
[docs]class StageBase(Picklable, metaclass=_StageBaseMeta): count = 0, Picklable stage_cls = {} stage_obj = {} # key = stage name; value = stage instance. NAME = "generic" RESTARTABLE = False # Whether or not a stage can be restarted after it's already ran # Basic stage parameters PARAM = _create_param_when_needed(""" DATA = { title = ? should_sync = true dryrun = false prefix = "" jobname = "$MAINJOBNAME_$STAGENO" dir = "$[$JOBPREFIX/$]$[$PREFIX/$]$MAINJOBNAME_$STAGENO" compress = "$MAINJOBNAME_$STAGENO-out%s" struct_output = "" should_skip = false effect_if = ? jlaunch_opt = [] transfer_asap = no } VALIDATE = { title = [{type = none} {type = str}] should_sync = {type = bool} dryrun = {type = bool} prefix = {type = str } jobname = {type = str } dir = {type = str } compress = {type = str } struct_output = {type = str } should_skip = {type = bool} effect_if = [{type = none} {type = list size = -2 _skip = all}] jlaunch_opt = { type = list size = 0 elem = {type = str} check = "" black_list = ["-HOST" "-USER" "-JOBNAME"] } transfer_asap = {type = bool} } """ % (PACKAGE_SUFFIX,))
[docs] def __init__(self, should_pack=True): # Will be set by the parser (see `parse_msj' function below). self.param = None self._PREV_STAGE = None # Stage object of the previous stage self._NEXT_STAGE = None # Stage object of the next stage self._ID = StageBase.count # ID number of this stage self._INDEX = None # Stage index. Not serialized. self._is_shown = False self._is_packed = False self._should_pack = should_pack # For parameter validation # function objects to be called before the main parameter check self._precheck = [] # function objects to be called after the main parameter check self._postcheck = [] self._files4pack = [] self._files4copy = [] self._pack_fname = "" # Has the `prestage' method been called? self._is_called_prestage = False self._used_jobname = [] self._start_time = None # Holds per stage start time self._stage_duration = None # Holds per stage duration time self._gpu_time = 0.0 # Accumulates total GPU time self._num_gpu_subjobs = 0 # Number of GPU subjobs self._packed_fnames = set() self._job_manager = JobManager() StageBase.count += 1

    @property
    def jobs(self) -> List[Job]:
        return self._job_manager.jobs

    def get_prejobs(self) -> List[Job]:
        """
        Get the stage's input jobs
        """
        if self._PREV_STAGE is None:
            return []
        return self._PREV_STAGE.get_output_jobs()

    def add_jobs(self, jobs: Iterable[Job]):
        """
        Add jobs to the stage's job manager
        """
        self._job_manager.add_jobs(jobs)

    def add_job(self, job: Job):
        """
        Shortcut for `add_jobs`
        """
        self._job_manager.add_jobs([job])

    def get_output_jobs(self) -> List[Job]:
        """
        Get the stage's output jobs
        """
        return self.filter_jobs(status=[JobStatus.SUCCESS], is_output=[True])

    def filter_jobs(self, **kwargs) -> List[Job]:
        """
        Return a list of jobs matching a set of criteria given as arguments.
        Read `JobManager.filter_jobs` for available arguments.
        """
        return self._job_manager.filter_jobs(**kwargs)
def __getstate__(self, state=None): state = state if (state) else picklejar.PickleState() state.NAME = self.NAME state._ID = self._ID state._is_shown = self._is_shown state._is_packed = self._is_packed state._job_manager = self._job_manager try: state._pack_fname = self._pack_fname except AttributeError: state._pack_fname = "" return state def __setstate__(self, state): if state.NAME != self.NAME: raise TypeError("Unmatched stage: %s vs %s" % (state.NAME, self.NAME)) self.__dict__.update(state.__dict__) def _print(self, loglevel, msg): _print(loglevel, msg) def _log(self, msg): self._print("quiet", "stage[%d] %s: %s" % (self._INDEX, self.NAME, msg)) def _get_macro_dict(self): macro_dict = copy.copy(ENGINE.macro_dict) macro_dict["$STAGENO"] = self._INDEX return macro_dict def _gen_unique_jobname(self, suggested_jobname): trial_jobname = suggested_jobname number = 1 while trial_jobname in self._used_jobname: trial_jobname = suggested_jobname + ("_%d" % number) number += 1 self._used_jobname.append(trial_jobname) sea.update_macro_dict({"$JOBNAME": trial_jobname}) return trial_jobname def _get_jobname_and_dir(self, job, macro_dict={}): # noqa: M511 sea.set_macro_dict(self._get_macro_dict()) sea.update_macro_dict(macro_dict) if self.param.prefix.val != "": sea.update_macro_dict({"$PREFIX": self.param.prefix.val}) if job.prefix != "" and job.prefix is not None: sea.update_macro_dict({"$JOBPREFIX": job.prefix}) try: if job.tag is not None: sea.update_macro_dict({"$JOBTAG": job.tag}) except AttributeError: pass util.chdir(ENGINE.base_dir) sea.update_macro_dict({"$JOBNAME": self.param.jobname.val}) return ( self.param.jobname.val, os.path.abspath(self.param.dir.val), ) def _param_jlaunch_opt_check(self, key, val_list, prefix, ev): try: black_list = set(self.PARAM.VALIDATE.jlaunch_opt.black_list.val) except AttributeError: return jlaunch_opt = set(val_list.val) bad_opt = jlaunch_opt & black_list if bad_opt: s = " ".join(bad_opt) ev.record_error( prefix, "Bad values for jlaunch_opt of %s stage: %s" % (self.NAME, s)) def _reg_param_precheck(self, func): if func not in self._precheck: self._precheck.append(func) def _reg_param_postcheck(self, func): if func not in self._postcheck: self._postcheck.append(func) def _set(self, key, setter, transformer=None): param = self.param[key] if param.has_tag("setbyuser"): if callable(setter): setter(param) elif isinstance(setter, sea.Atom): if callable(transformer): setter.val = transformer(param.val) else: setter.val = param.val def _effect(self, param): effect_if = param.effect_if if isinstance(effect_if, sea.List): for condition, block in zip(effect_if[0::2], effect_if[1::2]): # TODO: Don't use private function val = sea.evalor._eval(PARAM, condition) if isinstance(val, bool): condition = val elif isinstance(val[0], str): condition = _operator[val[0]](self, PARAM, val[1:]) else: condition = val[0] if condition: if isinstance(block, sea.Atom): block = sea.Map(block.val) # Checks if within the `block' is the 'effect_if' parameter # set. if "effect_if" not in block: block.effect_if = sea.Atom("none") # TODO what is the purpose of this line below? effect_if[1] = block block = block.dval param.update(block) self._effect(param) return param

    def describe(self):
        self._print("quiet", "\nStage %d - %s" % (self._INDEX, self.NAME))
        self._print(
            "verbose",
            "{\n" + self.param.__str__(" ", tag="setbyuser") + "}")

    def migrate_param(self, param: sea.Map):
        """
        Subclasses can implement this to migrate params to provide backward
        compatibility with older msj files, ideally with a deprecation
        warning.
        """
[docs] def check_param(self): def clear_trjidx(prmdata): """ do not use idx files """ try: if "maeff_output" in prmdata: del prmdata["maeff_output"]["trjidx"] except (KeyError, TypeError): pass check_func_name = "multisim_stage_%d_jlaunch_opt_check" % self._ID self.PARAM.VALIDATE.jlaunch_opt.check.val = check_func_name sea.reg_xcheck(check_func_name, self._param_jlaunch_opt_check) # Note that `self.param's parent should be the global `PARAM'. # But this statement will implicitly change its parent to `self.PARAM'. # At the end of this function we need to change it back to `PARAM'. orig_param_data = self.PARAM.DATA self.PARAM.DATA = self.param clear_trjidx(self.PARAM.DATA) ev = sea.Evalor(self.param, "\n") for func in self._precheck: try: func() except ParseError as e: ev.record_error(err=str(e)) sea.check_map(self.PARAM.DATA, self.PARAM.VALIDATE, ev, "setbyuser") for func in self._postcheck: try: func() except ParseError as e: ev.record_error(err=str(e)) self.param.set_parent(PARAM.stage) self.PARAM.DATA = orig_param_data return ev

    def push(self, job):
        if not self._is_called_prestage and not self.param.should_skip.val:
            self._is_called_prestage = True
            self.prestage()
        if job is None:
            self._print(
                "debug",
                "All surviving jobs have been pushed into stage[%d]." %
                self._INDEX)
            self.release()
        else:
            self._print(
                "debug",
                "Job was just pushed into stage[%d]: %s" %
                (self._INDEX, str(job)),
            )
            if not self.param.should_sync.val:
                self.release()

    def determine(self):
        param = self._effect(self.param)
        if param.should_skip.val:
            self.add_jobs(self.get_prejobs())

    def crunch(self):
        """
        This is where jobs of this stage are created. This function should be
        overridden by the subclass.
        """

    def restart_subjobs(self, jobs):
        """
        Subclass should override this if it supports subjob restarting.
        """
[docs] def release(self, is_restarting=False): """ Calls the 'crunch' method to generate new jobs objects and submits them to the 'QUEUE'. """ util.chdir(ENGINE.base_dir) if not self._is_shown: self.describe() self._is_shown = True is_restarting = True if not self.param.should_skip.val: self._is_packed = False if self._start_time is None: self._start_time = time.time() self.determine() if is_restarting: self.restart_subjobs(self.filter_jobs(is_restartable=[True])) if not self.filter_jobs(old=[False]) and not \ self.param.should_skip.val: # If no new jobs were created from restart_subjobs, run crunch self.crunch() jlaunch_opt = [str(e) for e in self.param.jlaunch_opt.val] if jlaunch_opt != [""]: for job in self.filter_jobs(is_for_jc=[True], status=[JobStatus.WAITING]): job.jlaunch_cmd += jlaunch_opt if not self.param.dryrun.val: ENGINE.write_checkpoint() self._job_manager.submit_jobs(QUEUE) for job in self.filter_jobs(is_for_jc=[False], old=[False]): if not job._has_run and callable(job.jlaunch_cmd): if not self.param.dryrun.val: job.jlaunch_cmd(job) job._has_run = True self.finalize_job(job) if self.param.dryrun.val: for job in self.filter_jobs(is_for_jc=[True], status=[JobStatus.WAITING]): job.status.set(JobStatus.SUCCESS) self.finalize_job(job) if self.jobs: self.finalize_stage()
[docs] def finalize_job(self, job: Job): """ Call `hook_captured_successful_job` on any successful jobs and write a checkpoint """ self._print("debug", "Captured %s" % job) if job.status == JobStatus.SUCCESS and not self.param.should_skip.val: self.hook_captured_successful_job(job) if self.param.transfer_asap.val: self.pack_stage(force=True) if job.USE_GPU: self._gpu_time += job.get_proc_time() self._num_gpu_subjobs += 1 ENGINE.write_checkpoint() self._print("debug", "running jobs:") self._print( "debug", self.filter_jobs(status=[JobStatus.WAITING, JobStatus.RUNNING], old=[False])) self._print("debug", "successful jobs:") self._print("debug", self.filter_jobs(status=[JobStatus.SUCCESS], old=[False])) self._print("debug", "failed jobs:") self._print("debug", self.filter_jobs(failed=[True], old=[False]))
[docs] def finalize_stage(self): """ If the stage is done running, pack the stage, and if the stage is successful, continue to the next stage """ running_jobs = self.filter_jobs( status=[JobStatus.WAITING, JobStatus.RUNNING], old=[False]) failed_jobs = self.filter_jobs(failed=[True], old=[False]) successful_jobs = self.filter_jobs(status=[JobStatus.SUCCESS], old=[False]) if not running_jobs: if not failed_jobs: # All jobs were successful if self.param.should_skip.val: self._print("quiet", f"\nStage {self._INDEX} is skipped.\n") else: self._print( "quiet", f"\nStage {self._INDEX} completed " f"successfully.\n") self.poststage() move_on = True elif successful_jobs and self._check_partial_success(): # Some stages can pass with partial success self.poststage() move_on = True else: # No jobs were successful self._print( "quiet", f"\nStage {self._INDEX} failed. " f"No subjobs completed.\n") move_on = False self.pack_stage(force=self.param.transfer_asap.val) if self._NEXT_STAGE is not None and move_on: self._NEXT_STAGE.push(None)

    def _check_partial_success(self):
        """
        Check whether or not the stage is considered successful based on
        whether it allows completion with some failed/some successful subjobs.
        Should be overridden by subclasses that need to implement this
        functionality.
        """
        return False

    def prestage(self):
        pass

    def poststage(self):
        pass

    def hook_captured_successful_job(self, job):
        pass

    def time_stage(self):
        this_stop_time = time.time()
        self._stage_duration = util.time_duration(self._start_time,
                                                  this_stop_time)

    def pack_stage(self, force=False):
        if force or ((not self.param.should_skip.val) and self._should_pack and
                     (not self._is_packed)):
            self._pack_stage()
def _pack_stage(self): self._is_packed = True util.chdir(ENGINE.base_dir) # Standard checkpoint to a file pack_fname = None if self.param.compress.val != "": sea.update_macro_dict({"$STAGENO": self._INDEX}) pack_fname = self.param.compress.val if not pack_fname.lower().endswith(( PACKAGE_SUFFIX, "tar.gz", )): pack_fname += PACKAGE_SUFFIX self.param.compress.val = pack_fname self._pack_fname = pack_fname print_debug(f"pack_stage: pack_fname:{pack_fname}") # Collects all data paths for transferring. data_paths = set() for job in self.jobs: # Some stages just pass on a job from the previous stage # directly to the next stage. So we check the stage ID # to avoid packing the same job again. if job.stage._ID == self._ID: if job.dir and pack_fname: data_paths.add(job.dir) else: for e in job.output: data_paths.add(e) reg_file = [] if isinstance(job.jctrl, jobcontrol.Job): reg_file.extend(job.jctrl.OutputFiles) reg_file.extend(job.jctrl.InputFiles) reg_file.extend(job.jctrl.LogFiles) if job.jctrl.StructureOutputFile: reg_file.append(job.jctrl.StructureOutputFile) for fname in reg_file: if not os.path.isabs(fname): data_paths.add(os.path.join(job.dir, fname)) # Creates a stage-specific checkpoint file -- just a symbolic link # to the current checkpoint file. ENGINE.write_checkpoint() stage_checkpoint_fname = None if os.path.isfile(CHECKPOINT_FNAME): stage_checkpoint_fname = (os.path.basename(CHECKPOINT_FNAME) + "_" + str(self._INDEX)) shutil.copyfile(CHECKPOINT_FNAME, stage_checkpoint_fname) # Includes this checkpoint file for transferring. data_paths.add(os.path.abspath(stage_checkpoint_fname)) if pack_fname: with tarfile.open(pack_fname, mode="w:gz", format=tarfile.GNU_FORMAT, compresslevel=1) as pack_file: pack_file.dereference = True for path in data_paths | set(self._files4pack): print_debug(f"pack_stage: add_to_tar: {path} exists: " f"{os.path.exists(path)} cwd: {os.getcwd()}") if os.path.exists(path): relpath = util.relpath(path, ENGINE.base_dir) pack_file.add(relpath) data_paths = [pack_fname] if ENGINE.JOBBE: for path in data_paths: # Makes all paths relative. Otherwise jobcontrol won't # transfer them!!! path = util.relpath(path, ENGINE.base_dir) if not path: continue if path in self._packed_fnames: continue self._packed_fnames.add(path) print_debug(f"pack_stage: outputFile: {path} relpath: " f"{util.relpath(path, ENGINE.base_dir)} " f"cwd: {os.getcwd()}") # Only when we do NOT compress files, we allow to transfer # files ASAP. DESMOND-7401. if (self.param.transfer_asap.val and not pack_fname) and os.path.exists(path): ENGINE.JOBBE.copyOutputFile(path) else: ENGINE.JOBBE.addOutputFile(path) for path in self._files4copy: path = util.relpath(path, ENGINE.base_dir) if path in self._packed_fnames: continue self._packed_fnames.add(path) print_debug(f"pack_stage: files4copy: {path} relpath: " f"{util.relpath(path, ENGINE.base_dir)} " f"cwd: {os.getcwd()}") ENGINE.JOBBE.copyOutputFile(path) try: self.time_stage() self._print( "quiet", "Stage %d duration: %s\n" % (self._INDEX, self._stage_duration)) except TypeError: self._print( "quiet", "Stage %d duration could not be calculated." % self._INDEX)


class JobManager:
    """
    A class for managing a stage's jobs. The jobs are stored in the `_jobs`
    list internally but should only be accessed by the `jobs` property or
    `filter_jobs`.
    """

    def __init__(self):
        self._jobs: List[Job] = []

    @property
    def jobs(self) -> List[Job]:
        return [*self._jobs]  # Return copy so list is not modified by user

    def clear(self):
        self._jobs = []

    def add_jobs(self, jobs: Iterable[Job]):
        """
        Add the given jobs to the job manager, but do not add duplicate jobs.
        """
        for job in jobs:
            job.old = False
        existing_jobs = set(self.jobs)
        self._jobs.extend(job for job in jobs if job not in existing_jobs)

    def submit_jobs(self, queue: queue.Queue):
        jobs = self.filter_jobs(status=[JobStatus.WAITING], is_for_jc=[True])
        queue.push(jobs)

    def filter_jobs(self,
                    status=None,
                    old=None,
                    is_for_jc=None,
                    is_output=None,
                    failed=None,
                    is_launchable=None,
                    is_restartable=None,
                    is_incomplete=None) -> List[Job]:
        """
        Get a subset of the job manager's jobs. Each argument can either be
        None, to indicate no filtering on the property, or a list of
        acceptable values for the given argument's property.

        When passing in multiple arguments, the function returns jobs which
        satisfy all given criteria.
        """

        def _filter_job(job):
            if status:
                if job.status not in status:
                    return False
            if old:
                if job.old not in old:
                    return False
            if is_for_jc:
                if job.is_for_jc not in is_for_jc:
                    return False
            if is_output:
                if job.is_output not in is_output:
                    return False
            if failed:
                if job.failed not in failed:
                    return False
            if is_launchable:
                if job.is_launchable not in is_launchable:
                    return False
            if is_restartable:
                if job.is_restartable not in is_restartable:
                    return False
            if is_incomplete:
                if job.is_incomplete not in is_incomplete:
                    return False
            return True

        return [*filter(_filter_job, self.jobs)]
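
    # Illustrative usage (not part of the original module): for example,
    #     manager.filter_jobs(status=[JobStatus.SUCCESS], is_output=[True])
    # selects a stage's successful output jobs, which is how
    # `StageBase.get_output_jobs` collects its results.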


class StructureStageBase(StageBase):
    """
    StructureStageBase can be used for stages that take in a path to a
    structure, apply some transformation, and then write out an updated
    structure.
    """

    def __init__(self, *args, **kwargs):
        self.TAG = self.NAME.upper()
        super().__init__(*args, **kwargs)

    def crunch(self):
        self._print("debug", f"In {self.NAME}.crunch")
        for pj in self.get_prejobs():
            jobname, jobdir = self._get_jobname_and_dir(pj)
            if not os.path.isdir(jobdir):
                os.makedirs(jobdir)
            with fileutils.chdir(jobdir):
                new_job = copy.deepcopy(pj)
                new_job.stage = weakref.proxy(self)
                new_job.output = JobOutput()
                new_job.need_host = False
                new_job.dir = jobdir
                new_job.status.set(JobStatus.SUCCESS)
                new_job.parent = pj
                output_fname = self.run(jobname, pj.output.struct_file())
                if output_fname is None:
                    new_job.status.set(JobStatus.BACKEND_ERROR)
                else:
                    new_job.output.set_struct_file(
                        os.path.abspath(output_fname))
                self.add_job(new_job)
        self._print("debug", f"Out {self.NAME}.crunch")

    def run(self, jobname: str, input_fname: str) -> Optional[str]:
        """
        :param jobname: Jobname for this stage.
        :param input_fname: Filename for the input structure.

        :return: Filename for the output structure or `None` if there was an
                 error generating the output.
        """
        raise NotImplementedError
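

# A minimal sketch (not part of the original module) of how a concrete stage
# might subclass StructureStageBase: `crunch` above handles the job
# bookkeeping, so a subclass only implements `run`, returning the output
# filename on success or None on failure. The names "strip_waters",
# StripWaters, and remove_waters below are hypothetical.
#
#     class StripWaters(StructureStageBase):
#         NAME = "strip_waters"
#
#         def run(self, jobname, input_fname):
#             out_fname = jobname + "-out.cms"
#             try:
#                 remove_waters(input_fname, out_fname)  # hypothetical helper
#             except Exception:
#                 return None
#             return out_fname
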
class _get_jc_backend_when_needed(object): def __get__(self, obj, cls): jobbe = jobcontrol.get_backend() setattr(cls, "JOBBE", jobbe) return jobbe
[docs]class Engine(object): JOBBE = _get_jc_backend_when_needed()
[docs] def __init__(self, opt=None): # This may be reset by the command options. self.jobname = None self.username = None self.mainhost = None self.host = None self.cpu = None self.inp_fname = None self.msj_fname = None # The .msj file of this restarting job. self.MSJ_FNAME = None # Original .msj file name. self.msj_content = None self.out_fname = None # Not serialized because it will be always reset at restarting self.set = None self.cfg = None self.cfg_content = None self.maxjob = None self.max_retry = None self.relay_arg = None self.launch_dir = None self.description = None self.loglevel = GENERAL_LOGLEVEL self.stage = [] # Serialized. Will be set when serialization. self.date = None # Date of the original job. self.time = None # Time of the original job. self.START_TIME = None # Start time of the original job. self.start_time = None # Start time. Will change in restarting. self.stop_time = None # Stop time. Will change in restarting. self.base_dir = None # Current base dir. Will change in restarting. # Stage No. to restart from. Will change in restarting. self.refrom = None self.base_dir_ = None # Base dir of last job. self.jobid = None # Current job ID. Will change in restarting. # Job ID of the original job. Not affected by restarting. self.JOBID = None # version numbers and installation will change in restarting self.version = VERSION # MSJ version. self.build = BUILD self.mmshare_ve = envir.CONST.MMSHARE_VERSION # Installation dir. Will change in restarting. self.schrodinger = envir.CONST.SCHRODINGER # Installation dir of the previous run. Will change in restarting. self.schrod_old = None self.old_jobnames = [] # Will be set when probing the checkpoint file self.chkpt_fname = None self.chkpt_fh = None self.restart_stage = None self.__more_init() if opt: self.reset(opt)
def __more_init(self): """ Will be called by '__init__' and 'deseriealize'. This is introduced to avoid breaking the previous checkpoint file by adding a new attribute. """ self.notify = None self.macro_dict = None self.max_walltime = None self.checkpoint_requested_event = None def __find_restart_stage_helper(self, stage): if stage.filter_jobs(is_incomplete=[True]) or not stage.jobs: self.restart_stage = self.restart_stage if ( self.restart_stage) else stage stage._is_shown = False stage._is_packed = False def _find_restart_stage(self): self.restart_stage = None self._foreach_stage(self.__find_restart_stage_helper) def _fix_job(self, stage): if not stage.RESTARTABLE and stage.filter_jobs(is_incomplete=[True]): stage._job_manager.clear() for job in stage.jobs: job.old = True if job.dir and job.dir.startswith(self.base_dir_ + os.sep): job.dir = job.dir.replace(self.base_dir_, self.base_dir) elif job.dir and job.dir == self.base_dir_: # With JOB_SERVER, the job dir may not be a subdirectory # so replace the top-level dir too. This is needed for # restarting MD jobs from the production stage. job.dir = job.dir.replace(self.base_dir_, self.base_dir) job.output.update_basedir(self.base_dir_, self.base_dir) try: job.input.update_basedir(self.base_dir_, self.base_dir) except AttributeError: pass # Fixes the "stage" attribute of all jobs of this stage. And fixes job # launching command. if isinstance(job.stage, int): job.stage = weakref.proxy(self.stage[job.stage]) if isinstance(job.jlaunch_cmd, list) and isinstance( job.jlaunch_cmd[0], str): job.jlaunch_cmd[0] = job.jlaunch_cmd[0].replace( self.schrod_old, self.schrodinger)
[docs] def restore_stages(self, print_func=print_quiet): # DESMOND-7934: Preserve the task stage from the checkpoint # if a custom msj is specified. checkpoint_stage_list = None if self.msj_fname and self.msj_content: checkpoint_stage_list = parse_msj(None, msj_content=self.msj_content, pset=self.set) parsee0 = "the multisim script file" if (self.msj_fname) else None parsee1 = "the '-set' option" if (self.set) else None parsee = (parsee0 + " and " + parsee1 if (parsee0 and parsee1) else parsee0 if (parsee0) else parsee1) if parsee: print_func("\nParsing %s..." % parsee) try: msj_content = None if (self.msj_fname) else self.msj_content stage_list = parse_msj(self.msj_fname, msj_content, self.set) except ParseError as a_name_to_make_flake8_happy: print_quiet("\n%s\nParsing failed." % str(a_name_to_make_flake8_happy)) sys.exit(1) if checkpoint_stage_list and stage_list: refrom = self.refrom # Find the restart stage index if not specified if refrom is None: # The first stage has the parameters we want # restore from the checkpoint. refrom = 2 for idx, s in enumerate(self.stage): if s.filter_jobs(is_incomplete=[True]) or not s.jobs: refrom = idx break # Restore stages before the restart stage from the checkpoint # and update the ones after the checkpoint stage_list = checkpoint_stage_list[:refrom - 1] + stage_list[refrom - 1:] if "task" != stage_list[0].NAME: print("ERROR: The first stage is not a 'task' stage.") sys.exit(1) if self.cfg: with open(self.cfg, "r") as fh: cfg = sea.Map(fh.read()) for stage in stage_list: if "task" == stage.NAME: if "desmond" in stage.param.set_family: stage.param.set_family.desmond.update(cfg) else: stage.param.set_family["desmond"] = cfg if self.cpu: # Value of `self.cpu' is a string, which specifies either a single # integer or 3 integers separated by spaces. We must parse the # string to get the integers and assign the latter to stages. cpu_str = self.cpu.split() try: cpu = [int(e) for e in cpu_str] n_cpu = len(cpu) cpu = cpu[0] if (1 == n_cpu) else cpu if 1 != n_cpu and 3 != n_cpu: raise ValueError("Incorrect configuration of the CPU: %s" % self.cpu) except ValueError: raise ParseError("Invalid value for the 'cpu' parameter: '%s'" % self.cpu) for stage in stage_list: if stage.NAME in [ "simulate", "minimize", "replica_exchange", "lambda_hopping", "vrun", "fep_vrun", "watermap", ]: stage.param["cpu"] = cpu stage.param.cpu.add_tag("setbyuser") elif stage.NAME in [ "mcpro_simulate", "watermap_cluster", "ffbuilder", ]: stage.param["cpu"] = (cpu if (1 == n_cpu) else (cpu[0] * cpu[1] * cpu[2])) stage.param.cpu.add_tag("setbyuser") stage_state = [ e.__getstate__() for e in (self.stage[:self.refrom] if ( self.refrom and self.refrom > 0) else self.stage) ] self.stage = build_stages(stage_list, self.out_fname, stage_state) for stage in self.stage: self._fix_job(stage) # `self.msj_content' contains only user's settings. `stage_list[1:-1]' # will avoid the initial ``primer'' and the final ``concluder'' stages. self.msj_content = write_msj(stage_list[1:-1], to_str=True)
[docs] def reset(self, opt): """ Resets this engine with the command options. """ # Resets the '_is_reset_*' attributes. for k in self.__dict__: if k[:10] == "_is_reset_": self.__dict__[k] = False if opt.refrom: self.refrom = opt.refrom if opt.jobname: self.jobname = opt.jobname if opt.user: self.username = opt.user if opt.mainhost: self.mainhost = opt.mainhost if opt.host: self.host = opt.host if opt.cpu: self.cpu = opt.cpu if opt.inp: self.inp_fname = os.path.abspath(opt.inp) if opt.msj: self.msj_fname = os.path.abspath(opt.msj) if opt.out: self.out_fname = opt.out if opt.set: self.set = opt.set if opt.maxjob is not None: self.maxjob = opt.maxjob if opt.max_retries is not None: self.max_retry = opt.max_retries if opt.relay_arg: self.relay_arg = sea.Map(opt.relay_arg) if opt.launch_dir: self.launch_dir = opt.launch_dir if opt.notify: self.notify = opt.notify if opt.encoded_description: self.description = cmdline.get_b64decoded_str( opt.encoded_description) if opt.quiet: self.loglevel = "quiet" if opt.verbose: self.loglevel = "verbose" if opt.debug: self.loglevel = "debug" if opt.max_walltime: self.max_walltime = opt.max_walltime self.cfg = opt.cfg
[docs] def boot_setup(self, base_dir=None): """ Set up an `Engine` object, but do not start the queue. :param base_dir: Set to the path for the base_dir or `None`, the default, to use the cwd. """ global ENGINE, GENERAL_LOGLEVEL, CHECKPOINT_FNAME self._init_signals() GENERAL_LOGLEVEL = self.loglevel ENGINE = self if self.loglevel == "debug": _print("quiet", "Multisim debugging mode is on.\n") if self.description: _print("quiet", self.description) ####################################################################### # Boots the engine. _print("quiet", "Booting the multisim workflow engine...") self.date = time.strftime("%Y%m%d") if (not self.date) else self.date self.time = time.strftime("%Y%m%dT%H%M%S") if ( not self.time) else self.time self.start_time = time.time() if ( not self.start_time) else self.start_time self.base_dir_ = self.base_dir self.base_dir = base_dir or os.getcwd() self.jobid = envir.get("SCHRODINGER_JOBID") self.JOBID = self.JOBID if (self.JOBID) else self.jobid self.maxjob = 0 if self.maxjob < 1 else self.maxjob self.max_retry = (self.max_retry if (self.max_retry is not None) else int( envir.get("SCHRODINGER_MAX_RETRIES", 3))) self.MSJ_FNAME = self.MSJ_FNAME if (self.MSJ_FNAME) else self.msj_fname # Resets these variables. self.version = VERSION self.build = BUILD self.mmshare_ver = envir.CONST.MMSHARE_VERSION self.schrod_old = self.schrodinger self.schrodinger = envir.CONST.SCHRODINGER _print("quiet", " multisim version: %s" % self.version) _print("quiet", " mmshare version: %s" % self.mmshare_ver) _print("quiet", " Jobname: %s" % self.jobname) _print("quiet", " Username: %s" % self.username) _print("quiet", " Main job host: %s" % self.mainhost) _print("quiet", " Subjob host: %s" % self.host) _print("quiet", " Job ID: %s" % self.jobid) _print( "quiet", " multisim script: %s" % os.path.basename(self.msj_fname if (self.msj_fname) else self.MSJ_FNAME)) _print( "quiet", " Structure input file: %s" % os.path.basename(self.inp_fname)) if self.cpu: _print("quiet", ' CPUs per subjob: "%s"' % self.cpu) else: _print("quiet", " CPUs per subjob: (unspecified in command)") _print("quiet", " Job start time: %s" % time.ctime(self.start_time)) _print("quiet", " Launch directory: %s" % self.launch_dir) _print("quiet", " $SCHRODINGER: %s" % self.schrodinger) if oplsdir := os.getenv(mm.OPLS_DIR_ENV): _print("quiet", f" $OPLS_DIR: {os.path.basename(oplsdir)}") # Only need to copy the file once if os.getenv(constants.SCHRODINGER_MULTISIM_DONT_COPY_OPLS) is None: os.environ[constants.SCHRODINGER_MULTISIM_DONT_COPY_OPLS] = '1' opls_fname = f'{self.jobname}-out.opls' if not Path(opls_fname).exists(): shutil.copyfile(oplsdir, opls_fname) self.JOBBE.copyOutputFile(opls_fname) else: _print("quiet", " $OPLS_DIR: <empty>") sys.stdout.flush() self.macro_dict = { "$MAINJOBNAME": self.jobname, "$MASTERJOBNAME": self.jobname, # TODO: Here to read old msj files "$USERNAME": self.username, "$SUBHOST": self.host, } sea.set_macro_dict(copy.copy(self.macro_dict)) self.restore_stages() if self.chkpt_fh: def show_job_state(stage, engine=self): engine._check_stage(stage) if stage._final_status[0] == "1": _print("quiet", " Jobnames of failed subjobs:") for job in stage.filter_jobs(failed=[True]): _print("quiet", " %s" % job.jobname) _print("quiet", "") _print("quiet", "Checkpoint state:") self._foreach_stage(show_job_state) _print("quiet", "\nSummary of user stages:") for stage in self.stage[1:-1]: if stage.param.title.val: _print( "quiet", " stage %d - %s, %s" % (stage._INDEX, stage.NAME, 
stage.param.title.val)) else: _print("quiet", " stage %d - %s" % (stage._INDEX, stage.NAME)) _print("quiet", "(%d stages in total)" % (len(self.stage) - 2)) CHECKPOINT_FNAME = os.path.join( self.base_dir, sea.expand_macro(CHECKPOINT_FNAME, sea.get_macro_dict()))
[docs] def boot(self): """ Boot the `Engine` and run the jobs. """ global QUEUE self.boot_setup() max_walltime_timer = None if self.max_walltime: self.checkpoint_requested_event = threading.Event() _print("quiet", f"Checkpoint after {self.max_walltime} seconds.") max_walltime_timer = threading.Timer( self.max_walltime, lambda: self.checkpoint_requested_event.set()) max_walltime_timer.start() if self.host is None: _print( "quiet", "\nCould not determine host. " "Please check schrodinger.hosts and the queue configuration and try again." ) self.cleanup(exit_code=1, skip_stage_check=True) return QUEUE = queue.Queue(self.host, self.maxjob, max_retries=self.max_retry, periodic_callback=self.handle_jobcontrol_message) self.JOBBE.addOutputFile(os.path.basename(CHECKPOINT_FNAME)) self.start_time = time.time() _print("quiet", "\nWorkflow is started now.") try: if self.START_TIME is None: self.START_TIME = self.start_time self.stage[0].start(self.inp_fname) else: if self.refrom is None or self.refrom < 1: self._find_restart_stage() else: self.restart_stage = self.stage[self.refrom] if self.restart_stage: if self.msj_fname: _print( "quiet", "Updating stages with the new .msj file: " f"{self.msj_fname}...") _print( "quiet", f"Stage {self.restart_stage._INDEX} and after " "will be affected by the new .msj file.") # We need to rerun the `set_family' functions. self.run_set_family(self.restart_stage._INDEX) _print( "quiet", "Restart workflow from stage %d." % self.restart_stage._INDEX) self.restart_stage.push(None) else: _print( "quiet", "The previous multisim job has completed successfully.") _print( "quiet", "If you want to restart from a completed stage, " "specify its stage number to") _print( "quiet", "the '-RESTART' option as: " "-RESTART <checkpoint-file>:<stage_number>.") QUEUE.run() exit_code = 0 skip_stage_check = False except SystemExit: sys.exit(1) except StopRequest: restart_fname = queue.CHECKPOINT_REQUESTED_FILENAME with open(restart_fname, 'w') as f: pass self.JOBBE.addOutputFile(os.path.basename(restart_fname)) exit_code = 0 skip_stage_check = True except StopAndRestartRequest: restart_fname = queue.CHECKPOINT_WITH_RESTART_REQUESTED_FILENAME with open(restart_fname, 'w') as f: pass self.JOBBE.addOutputFile(os.path.basename(restart_fname)) exit_code = 0 skip_stage_check = True except Exception: ei = sys.exc_info() sys.excepthook(ei[0], ei[1], ei[2]) _print( "quiet", "\n\nUnexpected exception occurred. Terminating the " "multisim execution...") exit_code = 1 skip_stage_check = False if max_walltime_timer is not None: max_walltime_timer.cancel() self.cleanup(exit_code, skip_stage_check=skip_stage_check)
[docs] def run_set_family(self, max_stage_idx=None): """ Re-run set_family for all task stages up to `max_stage_idx`. """ max_stage_idx = max_stage_idx or len(self.stage) stage = self.stage[0] while stage is not None and stage._INDEX < max_stage_idx: if stage.NAME == "task": stage.set_family() stage = stage._NEXT_STAGE
[docs] def handle_jobcontrol_message(self, stop=False): restart = False if self.checkpoint_requested_event is not None: restart = self.checkpoint_requested_event.is_set() if not (stop or restart or self.JOBBE.haltRequested()): return _print("quiet", "\nRecieved 'halt' message. Stopping job on user's request...") _print("quiet", f"{len(QUEUE.running_jobs)} subjob(s) are currently running.") num_killed = QUEUE.stop() if num_killed: _print("quiet", f"{num_killed} subjob(s) failed to stop and were killed.") else: _print("quiet", "Subjobs stopped successfully.") if restart: raise StopAndRestartRequest() raise StopRequest()
def _init_signals(self): # Signal handling stuff. for signal_name in [ "SIGTERM", "SIGINT", "SIGHUP", "SIGUSR1", "SIGUSR2" ]: # Certain signals are not available depending on the OS. if hasattr(signal, signal_name): signal.signal( getattr(signal, signal_name), lambda x, stack_frame: self._handle_signal(signal_name), ) def _reset_signals(self): signal.signal(signal.SIGTERM, signal.SIG_DFL) signal.signal(signal.SIGINT, signal.SIG_DFL) signal.signal(signal.SIGUSR1, signal.SIG_DFL) signal.signal(signal.SIGUSR2, signal.SIG_DFL) try: signal.signal(signal.SIGHUP, signal.SIG_DFL) except AttributeError: pass def _handle_signal(self, signal_name): self._reset_signals() print("\n\n%s: %s signal received" % (time.asctime(), signal_name)) return self.handle_jobcontrol_message(stop=True) def _foreach_stage(self, callback): stage = self.stage[0]._NEXT_STAGE while stage._NEXT_STAGE is not None: callback(stage) stage = stage._NEXT_STAGE def _check_stage(self, stage, print_func=print_quiet): INTERPRETATION = { -2: "2 was skipped", -1: "0 not run", 0: "0 failed", 1: "1 partially completed", 2: "2 completed", } subjob = "" if stage._is_shown: if stage.param.should_skip.val: status = INTERPRETATION[-2] else: num_done = len( stage.filter_jobs(status=[JobStatus.SUCCESS], old=[False])) num_incomplete = len( stage.filter_jobs(is_incomplete=[True], old=[False])) if num_done > 0: if num_incomplete == 0: status = INTERPRETATION[2] else: status = INTERPRETATION[1] subjob = " %d subjobs failed, %d subjobs done." % ( num_incomplete, num_done) else: if num_incomplete > 0: status = INTERPRETATION[0] else: status = INTERPRETATION[-1] else: status = INTERPRETATION[-1] print_func(" Stage %d %s.%s" % (stage._INDEX, status[2:], subjob)) stage._final_status = status
[docs] def cleanup(self, exit_code=0, skip_stage_check=False): """ :param skip_stage_check: Set to True to skip checking each stage to determine the exit code. """ print("Cleaning up files...") sys.stdout.flush() self._foreach_stage( lambda stage: stage._is_shown and stage.pack_stage()) self.stop_time = time.time() job_duration = util.time_duration(self.start_time, self.stop_time) print("\nMultisim summary (%s):" % time.ctime(self.stop_time)) self._foreach_stage(self._check_stage) # FIXME: duration for this restarting? print(" Total duration: %s" % job_duration) all_gpu_times = [] all_gpu_subjobs = [] self._foreach_stage(lambda stage: all_gpu_times.append(stage._gpu_time)) self._foreach_stage( lambda stage: all_gpu_subjobs.append(stage._num_gpu_subjobs)) total_gpu_time = sum(all_gpu_times) if total_gpu_time: print(" Total GPU time: %s (used by %d subjob(s))" % (_time_to_time_str(total_gpu_time), sum(all_gpu_subjobs))) final_status = [] for stage in self.stage[1:-1]: if stage.filter_jobs(old=[False]): final_status.append(int(stage._final_status[0])) if final_status: is_successful = min(final_status) else: is_successful = 0 # Fail if no stages ran if exit_code == 0: if is_successful == 2: print("Multisim completed.") elif is_successful == 1: print("Multisim partially completed.") else: print("Multisim failed.") else: print("Multisim failed.") if self.notify: recipients = (self.notify if (isinstance(self.notify, list)) else [ self.notify, ]) print("\nSending log file to the email address(es): %s" % ", ".join(recipients)) sys.stdout.flush() log_fname = self.jobname + "_multisim.log" if os.path.isfile(log_fname): email_message = open(log_fname, "r").read() else: email_message = "Log file: %s not found.\n" email_message += str(self.JOBID) + "\n" email_message += self.launch_dir + "\n" email_message += self.description + "\n" if exit_code == 0: if is_successful == 2: email_message += "Multisim completed." elif is_successful == 1: email_message += "Multisim partially completed." else: email_message += "Multisim failed." else: email_message += "Multisim failed." import smtplib from email.mime.text import MIMEText composer = MIMEText(email_message) composer["Subject"] = "Multisim: %s" % self.jobname composer["From"] = "noreply@schrodinger.com" composer["To"] = ", ".join(recipients) try: smtp = smtplib.SMTP() smtp.connect() smtp.sendmail("noreply@schrodinger.com", recipients, composer.as_string()) smtp.close() except Exception: print("WARNING: Failed to send notification email.") print("WARNING: There is probably no SMTP server running on " "main host.") if exit_code == 0 and is_successful != 2 and not skip_stage_check: exit_code = 1 sys.exit(exit_code)
[docs]    def serialize(self, fh: BinaryIO):
        self.msj_fname = None
        self.set = None
        self.refrom = None
        self.chkpt_fh = None
        self.stop_time = time.ctime()
        pickle.dump(self, fh)
        PickleJar.serialize(fh)
[docs]    def serialize_bytes(self) -> bytes:
        """
        Return the binary contents of the serialized engine.
        """
        fh = BytesIO()
        self.serialize(fh)
        fh.flush()
        return fh.getvalue()
    def __getstate__(self):
        tmp_dict = copy.copy(self.__dict__)
        # Can't checkpoint event
        tmp_dict["checkpoint_requested_event"] = None
        return tmp_dict
[docs]    @staticmethod
    def deserialize(fh: BinaryIO):
        unpickler = picklejar.CustomUnpickler(fh, encoding="latin1")
        engine = unpickler.load()
        # This adds class metadata that was serialized
        # above. Without this, these values are reset to
        # the default.
        PickleJar.deserialize(fh)
        engine.chkpt_fh = fh
        engine.__more_init()
        try:
            engine.old_jobnames.append(engine.jobname)
        except AttributeError:
            engine.old_jobnames = [
                engine.jobname,
            ]
        return engine
[docs]    def write_checkpoint(self, fname=None, num_retry=10):
        if not fname:
            fname = CHECKPOINT_FNAME
        # Write to a temporary file
        fname_lock = fname + ".lock"
        with open(fname_lock, "wb") as fh:
            self.serialize(fh)
        for i in range(num_retry):
            try:
                # not available in py2
                os.replace(fname_lock, fname)
                return
            except AttributeError:
                # rename fails on Windows if the destination already exists
                if os.path.isfile(fname):
                    os.remove(fname)
                os.rename(fname_lock, fname)
            except PermissionError as err:
                # TODO: DESMOND-9511
                print(i, os.getcwd(), fname_lock, fname)
                for fn in glob.glob("*"):
                    print(i, fn)
                if i == num_retry - 1:
                    raise err
                else:
                    print(f"retry {i+1} due to err: {err}")
                    time.sleep(30)
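    # The checkpoint is first serialized to a "<fname>.lock" scratch file and
    # then moved into place with os.replace(), so a reader never sees a
    # partially written checkpoint. A minimal usage sketch (assuming `engine`
    # is an Engine instance; the explicit file name below is illustrative):
    #
    #     engine.write_checkpoint()
    #     engine.write_checkpoint("myjob-multisim_checkpoint")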
[docs]class StopRequest(Exception):
    pass


[docs]class StopAndRestartRequest(Exception):
    pass


[docs]class ParseError(Exception):
    pass
[docs]def is_restartable_version(version_string):
    version_number = [int(e) for e in version_string.split(".")]
    current = [int(e) for e in VERSION.split(".")]
    for v, c in zip(version_number[:3], current[:3]):
        if v < c:
            return False
    return True
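# The check above compares only the first three numeric components and flags a
# checkpoint as non-restartable if any component is lower than the current
# VERSION. A minimal sketch, assuming VERSION is "4.0.0" as defined at the top
# of this module:
#
#     is_restartable_version("4.0.0")   # -> True  (same version)
#     is_restartable_version("3.9.2")   # -> False (major component is older)
#     is_restartable_version("4.0.1")   # -> True  (no component is older)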
[docs]def is_restartable_build(engine):
    from . import bld_def as bd

    bld_comm = bd.bld_types[bd.DESMOND_COMMERCIAL]
    try:
        restart_files_build = engine.build
    except AttributeError:
        return True
    return restart_files_build != bld_comm or BUILD == bld_comm
[docs]def build_stages(stage_list, out_fname=None, stage_state=[]):  # noqa: M511
    """
    Build up the stages for the job, adding the initial Primer and final
    Concluder stages.
    """
    import schrodinger.application.desmond.stage as stg

    primer_stage = stg.Primer()
    concluder_stage = stg.Concluder(out_fname)
    primer_stage.param = copy.deepcopy(stg.Primer.PARAM.DATA)
    concluder_stage.param = copy.deepcopy(stg.Concluder.PARAM.DATA)
    stage_list.insert(0, primer_stage)
    stage_list.append(concluder_stage)

    build_stagelinks(stage_list)

    for stage, state in zip(stage_list, stage_state):
        if stage.NAME == state.NAME:
            stage.__setstate__(state)

    return stage_list
[docs]def probe_checkpoint(fname, indent=""):
    print(indent + "Probing checkpoint file: %s" % fname)
    with open(fname, "rb") as fh:
        engine = Engine.deserialize(fh)
        engine.schrod_old = engine.schrodinger

        def probe_print(s):
            print(indent + " " + s)

        probe_print(" multisim version: %s" % engine.version)
        probe_print(" mmshare version: %s" % engine.mmshare_ver)
        probe_print(" Jobname: %s" % engine.jobname)
        probe_print(" Previous jobnames: %s" % engine.old_jobnames)
        probe_print(" Username: %s" % engine.username)
        probe_print(" Main job host: %s" % engine.mainhost)
        probe_print(" Subjob host: %s" % engine.host)
        if engine.cpu:
            probe_print(' CPUs per subjob: "%s"' % engine.cpu)
        else:
            probe_print(" CPUs per subjob: unspecified in command")
        probe_print(" Original start time: %s" %
                    time.ctime(engine.START_TIME))
        probe_print(" Checkpoint time: %s" % engine.stop_time)
        probe_print(" Main job ID: %s" % engine.jobid)
        probe_print(" Structure input file: %s" %
                    os.path.basename(engine.inp_fname))
        probe_print(" Original *.msj file: %s" %
                    os.path.basename(engine.MSJ_FNAME))

        engine.base_dir_ = engine.base_dir
        engine.restore_stages(print_func=print_tonull)

        probe_print("\nStages:")
        engine.chkpt_fname = fname

        def show_failed_jobs(stage, engine=engine):
            engine._check_stage(stage, probe_print)
            if stage._final_status[0] == "1":
                probe_print(" Jobnames of failed subjobs:")
                for job in stage.filter_jobs(failed=[True]):
                    probe_print("   %s" % job.jobname)

        engine._foreach_stage(show_failed_jobs)

    print()
    print("Current version of multisim is %s" % VERSION)
    print("This checkpoint file "
          "can%sbe restarted with the current version of multisim." %
          (" " if is_restartable_version(engine.version) else " not "))
    return engine
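# Usage sketch (the checkpoint file name below is illustrative): probing prints
# a human-readable summary of the checkpointed job and returns the deserialized
# Engine for further inspection.
#
#     engine = probe_checkpoint("myjob-multisim_checkpoint", indent="  ")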
[docs]def escape_string(s):
    ret = ""
    should_quote = False
    if s == "":
        return '""'
    for c in s:
        if c == '"':
            ret += '\\"'
            should_quote = True
        elif c == "'" and ret[-1] == "\\":
            ret = ret[:-1] + "'"
            should_quote = True
        else:
            ret += c
            if c <= " ":
                should_quote = True
    if should_quote:
        ret = '"' + ret + '"'
    return ret
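# Illustrative behavior (a sketch, not an exhaustive specification): strings
# containing whitespace or double quotes are wrapped in quotes so they survive
# re-parsing as a single msj token.
#
#     escape_string("")            # -> '""'
#     escape_string("abc")         # -> 'abc'
#     escape_string("two words")   # -> '"two words"'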
[docs]def append_stage(
        cmj_fname,
        stage_type,
        cfg_file=None,
        jobname=None,
        dir=None,
        compress=None,
        parameter={},  # noqa: M511
):
    if not os.path.isfile(cmj_fname):
        return None

    try:
        with open(cmj_fname, "r") as fh:
            s = fh.read()
    except IOError:
        print("error: Reading failed. file: '%s'" % cmj_fname)
        return None

    if stage_type == "simulate":
        s += "simulate {\n"
    elif stage_type == "minimize":
        s += "minimize {\n"
    elif stage_type == "replica_exchange":
        s += "replica_exchange {\n"
    else:
        print("error: Unknown stage type '%s'" % stage_type)
        return None

    if cfg_file is not None:
        s += ' cfg_file = "%s"\n' % cfg_file
    if jobname is not None:
        s += ' jobname = "%s"\n' % jobname
    if dir is not None:
        s += ' dir = "%s"\n' % dir
    if compress is not None:
        s += ' compress = "%s"\n' % compress
    for p in parameter:
        if parameter[p] is not None:
            s += " %s = %s\n" % (p, parameter[p])
    s += "}\n"
    return s
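# Hypothetical usage sketch (the file name and values below are made up):
# append_stage returns the original msj text with a new stage block appended,
# or None on error; it does not write anything back to disk.
#
#     text = append_stage("job.msj", "simulate",
#                         cfg_file="md.cfg", jobname="prod",
#                         parameter={"time": 1200.0})
#     # `text` now ends with:
#     #     simulate {
#     #      cfg_file = "md.cfg"
#     #      jobname = "prod"
#     #      time = 1200.0
#     #     }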
[docs]def concatenate_relaxation_stages(raw):
    """
    Attempts to concatenate relaxation stages by finding all adjacent
    non-production `simulate` stages. If no concatenatable stages are found,
    None is returned. Otherwise, a new raw map with the relaxation `simulate`
    stages replaced with a single `concatenate` stage is returned.

    :param raw: the raw map representing the MSJ
    :type raw: `sea.Map`

    :return: a new raw map representing the updated msj, or None.
    :rtype: `sea.Map` or `None`
    """
    new_raw = copy.deepcopy(raw)
    while True:
        stages_to_concat, insertion_point = get_concat_stages(new_raw.stage)
        if len(stages_to_concat) > 1:
            concat_stage = sea.Map()
            concat_stage.__NAME__ = "concatenate"
            concat_simulate_stages = sea.List()
            concat_simulate_stages.add_tag("setbyuser")
            for stage in stages_to_concat:
                new_raw.stage.remove(stage)
                concat_simulate_stages.append(stage)
            concat_stage.simulate = concat_simulate_stages
            concat_stage.title = concat_stage.simulate[0].title
            if 'maeff_output' in concat_stage.simulate[0].val:
                concat_stage.maeff_output = concat_stage.simulate[
                    0].maeff_output
            new_raw.stage.insert(insertion_point, concat_stage)
            new_raw.stage.add_tag("setbyuser", propagate=False)
        else:
            break
    if len(new_raw.stage) != len(raw.stage):
        return new_raw
    return None
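# A minimal usage sketch (assuming `raw` is a parsed msj map such as the one
# returned by msj2sea later in this module): the function does not modify its
# argument, so callers typically fall back to the original map when None is
# returned.
#
#     new_raw = concatenate_relaxation_stages(raw)
#     raw = new_raw if new_raw is not None else raw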
[docs]def get_concat_stages(stages, param_attr=""):
    """
    Get a list of the stages that can be concatenated together, and the
    insertion point of the resulting concatenate stage. Stages can be
    concatenated if they are adjacent simulate stages with the same
    restraints, excluding the final production stage, which can be lambda
    hopping, replica exchange, or otherwise the last simulate stage.

    :param stages: A list of objects representing multisim stages. For
        flexibility, these can be either maps or stages. For stages, a param
        attribute must be passed that will give the location of the param on
        the stage.
    :type stages: list of (sea.Map or stage.Stage)

    :param param_attr: optional name of the attribute of the objects param,
        in case of a stage.Stage object.
    :type param_attr: str
    """
    stages_to_concat = []
    insertion_point = None
    i = last_stage = 0
    has_permanent_restrain = False
    first_simulate_param = None
    last_gcmc_block = None

    def is_restrained(param):
        return (("restrain" in param and param.restrain.val != "none") or
                bool(has_explicit_restraints(param)))

    for stage in stages:
        stage_param = getattr(stage, param_attr) if param_attr else stage
        try:
            if stage_param.should_skip.val:
                # don't let skipped stages break up otherwise consecutive
                # simulate stages
                if last_stage:
                    last_stage = i
                i += 1
                continue
        except AttributeError:
            pass

        name = stage_param.__NAME__
        # TODO we can't check stage.AssignForcefield.NAME here because we can't
        # import module stage (would be circular). That's pretty strong
        # evidence that we should move these concatenation-related functions to
        # a stage_utils module
        if name == "assign_forcefield":
            has_permanent_restrain |= is_restrained(stage_param)

        if name in _PRODUCTION_SIMULATION_STAGES:
            break
        elif name == "simulate":
            # simulate stages must be adjacent to concatenate
            if last_stage and last_stage != i - 1:
                break

            # gcmc stages can only be concatenated if gcmc blocks are identical
            # across stages
            if "gcmc" in stage_param.keys(tag="setbyuser"):
                gcmc_param = stage_param.gcmc
                if (last_gcmc_block is not None and
                        gcmc_param.val != last_gcmc_block.val):
                    break
            else:
                gcmc_param = sea.Atom("none")
            last_gcmc_block = gcmc_param

            # conditions on restrain block to concatenate
            if first_simulate_param is None:
                # we use whole `stage_param` instead of the restraints
                # themselves to (partially) support both old-style "restrain"
                # and new-style "restraints" in "Concatenate" stage (single
                # "flavor" per stage); ideally this needs to be
                # revised/tightened at some point during or after DESMOND-10079
                first_simulate_param = stage_param
            if restraints_incompatible(stage_param, first_simulate_param,
                                       has_permanent_restrain):
                break

            if insertion_point is None:
                insertion_point = i
            last_stage = i
            stages_to_concat.append(stage)
        i += 1

    # the production stage can be either the last simulate stage or one of those
    # defined in _PRODUCTION_SIMULATION_STAGES. if we've reached the last stage
    # without breaking it means the production stage is a normal simulate stage.
    # In that case, we need to remove the production stage from the list of
    # stages to concatenate
    if i == len(stages) and stages_to_concat:
        stages_to_concat.pop()

    return stages_to_concat, insertion_point
[docs]def make_empty_restraints(existing='ignore') -> sea.Map:
    outcome = sea.Map()
    outcome["existing"] = existing
    outcome["new"] = sea.List()
    return outcome


[docs]def get_restrain(sm: sea.Map) -> sea.Sea:
    try:
        return sm.get_value("restrain")
    except KeyError:
        return sea.Atom("none")


[docs]def get_restraints(sm: sea.Map) -> sea.Map:
    try:
        return sm.get_value("restraints")
    except KeyError:
        return make_empty_restraints()
[docs]def get_restraints_xor_convert_restrain(param: sea.Map) -> sea.Map:
    """
    Returns `restraints` or `restrain` (converted into `restraints` format)
    from the `param`. Raises `ValueError` if both are set.

    :param param: stage parameters

    :return: restraints block
    """
    restrain = get_restrain(param)
    if has_explicit_restraints(param):
        if restrain.val != 'none':
            raise ValueError("Concatenate stage cannot include "
                             "`restrain` and `restraints` simultaneously")
        else:
            return get_restraints(param)
    else:
        return _restraints_from_restrain(restrain)
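# Hedged sketch of the conversion performed when only an old-style `restrain`
# block is present (the block below is illustrative, and the exact sea string
# syntax is an assumption): the result always has the new-style
# `existing`/`new` layout used by the `restraints` parameter.
#
#     param = sea.Map('restrain = {atom = solute force_constant = 1.0}')
#     blk = get_restraints_xor_convert_restrain(param)
#     # blk["existing"] -> "ignore"
#     # blk["new"]      -> a one-element list holding the old restrain block,
#     #                    with `force_constants` copied from `force_constant`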
[docs]def restraints_incompatible(param: sea.Map, initial_param: sea.Map,
                            has_permanent_restrain: bool):
    """
    Returns whether restraints parameters are compatible with switching
    during a concatenate stage. For compatibility the parameters have to
    differ from the initial ones by only a scaling factor (which can include
    zero). Furthermore, there can be no differences between restraints and
    initial restraints if `permanent_restrain` is truthy, as there is no way
    to selectively scale restraints.

    :param param: the param for a given stage
    :type param: `sea.Map`

    :param initial_param: parameters for the first stage
    :type initial_param: `sea.Map`

    :param has_permanent_restrain: whether or not there are restraints
        applied to all stages via the `permanent_restraints` mechanism
    :type has_permanent_restrain: bool

    :return: a message declaring how the restraints are incompatible, or an
        empty string if they are compatible
    :rtype: str
    """
    param_restrain = get_restrain(param)
    initial_param_restrain = get_restrain(initial_param)

    have_restrain = (param_restrain.val != "none" or
                     initial_param_restrain.val != "none")
    have_restraints = (has_explicit_restraints(param) or
                       has_explicit_restraints(initial_param))

    if have_restrain and have_restraints:
        return ("We cannot concatenate stages that mix restraints "
                "given via the `restraints` and `restrain` parameters")

    if have_restrain:
        current = _restraints_from_restrain(param_restrain)
        initial = _restraints_from_restrain(initial_param_restrain)
    else:
        current = get_restraints(param)
        initial = get_restraints(initial_param)

    return _check_restraints_compatibility(
        current=current,
        initial=initial,
        has_permanent_restrain=has_permanent_restrain)
[docs]def has_explicit_restraints(param: sea.Map):
    """
    :param param: the param for a given stage

    :return: whether or not the `restraints` block has new or existing
        restraints
    """
    if "restraints" in param:
        explicit_restraints = param.restraints
        has_new = "new" in explicit_restraints and explicit_restraints.new.val
        has_existing = ("existing" in explicit_restraints and
                        explicit_restraints.existing.val !=
                        constants.EXISTING_RESTRAINT.IGNORE)
        return has_new or has_existing
    return False
[docs]def check_restrain_diffs(restrain, initial_restrain):
    """
    See if the differences between two restrain blocks are
    concatenation-compatible, meaning they are both `sea.Map` objects and
    differ only by a force constant.

    :param restrain: the restrain block for a given stage
    :type restrain: `sea.Map` or `sea.List`

    :param initial_restrain: the restraints for the first stage
    :type initial_restrain: `sea.Map` or `sea.List`

    :return: a message declaring how the restraints are incompatible, or an
        empty string if they are compatible
    :rtype: str
    """
    if restrain == initial_restrain:
        return ""

    def head_if_single(o):
        return o[0] if isinstance(o, sea.List) and len(o) == 1 else o

    restrain = head_if_single(restrain)
    initial_restrain = head_if_single(initial_restrain)

    if isinstance(restrain, sea.Map) and isinstance(initial_restrain, sea.Map):
        for restrain_diff in sea.diff(restrain, initial_restrain):
            for key in restrain_diff:
                if key not in ["force_constant", "fc", "force_constants"]:
                    return ("We cannot change restraint parameters other than "
                            "the force constant between integrators")
        return ""
    elif isinstance(restrain, sea.List) or isinstance(initial_restrain,
                                                      sea.List):
        return ("We cannot change between lists of restraint parameters "
                "unless they are identical.")
    else:
        raise ValueError("restraints definition blocks expected to be "
                         "`sea.List` or `sea.Map`")
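# Sketch of the rule enforced above (the literal blocks are illustrative, not
# taken from a real msj, and the sea string syntax is an assumption): two
# restrain maps may differ in their force constant but in nothing else.
#
#     a = sea.Map('atom = "solute_heavy_atom" force_constant = 50.0')
#     b = sea.Map('atom = "solute_heavy_atom" force_constant = 2.0')
#     c = sea.Map('atom = "solvent_heavy_atom" force_constant = 50.0')
#     check_restrain_diffs(a, b)   # -> ""  (only the force constant differs)
#     check_restrain_diffs(a, c)   # -> non-empty incompatibility message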
def _check_restraints_compatibility(initial: sea.Map, current: sea.Map,
                                    has_permanent_restrain: bool) -> str:
    """
    Returns whether the restraints parameters are compatible with switching
    during a concatenate stage. For compatibility, `current` has to differ
    from the `initial` by only a scaling factor (which can include zero).

    :param initial: preceding `restraints` block
    :type initial: `sea.Map`

    :param current: `restraints` block
    :type current: `sea.Map`

    :param has_permanent_restrain: whether or not there are restraints
        applied to all stages via the `permanent_restraints` mechanism
    :type has_permanent_restrain: bool

    :return: a message declaring how the restraints are incompatible, or an
        empty string if they are compatible
    :rtype: str
    """

    def get(m, n):
        return m[n].val if n in m else None

    def is_none(r):
        return get(r, 'existing') == 'ignore' and not get(r, 'new')

    def is_retain(r):
        return get(r, 'existing') == 'retain' and not get(r, 'new')

    if current != initial and not is_retain(current):
        # there can be no difference between restraints
        # blocks if system has permanent restraints
        if has_permanent_restrain:
            return ("Subsequent simulate blocks cannot have differing "
                    "restrain blocks when permanent restraints are used")
        # we cannot go from no restrain to some restrain
        if is_none(initial):
            return ("Subsequent simulate blocks cannot have restrain block "
                    "unless the first simulate block or concatenate stage "
                    "does")
        elif not is_none(current):  # none is acceptable
            if current.existing != initial.existing:
                return ("Subsequent simulate blocks cannot have "
                        "differing restraints")
            else:
                return check_restrain_diffs(current.new, initial.new)
    return ""


def _restraints_from_restrain(
        old: Union[sea.Atom, sea.List, sea.Map]) -> sea.Map:
    """
    Translates old-style restraints specification ("restrain") into
    equivalent new-style blurb. Current version is incomplete, limited to the
    features needed for the concatenation support.

    :param old: old-style "restrain" block (string, map or list)

    :return: equivalent new-style "restraints" block
    """
    outcome = make_empty_restraints(
        existing='retain' if old.val == 'retain' else 'ignore')
    if old.val in ('none', 'retain'):
        pass
    elif isinstance(old, sea.Map):
        outcome["new"].append(old)  # copies `old`
    elif isinstance(old, sea.List):
        outcome["new"].extend(old)  # copies `old`
    else:
        raise ValueError("`restrain` block must be `none`, `retain`, "
                         "`sea.Map` or `sea.List`")
    for blk in outcome["new"]:
        fc = blk["fc"] if "fc" in blk else blk.force_constant
        blk["force_constants"] = fc  # copies `fc`
    return outcome


PARAM = None  # `sea.Map' object containing the whole job's msj setting
[docs]def msj2sea(fname, msj_content=None):
    """
    Parses a file as specified by 'fname' or a string given by 'msj_content'
    (if both are given, the former will be ignored), and returns a 'sea.Map'
    object that represents the stage settings with a structure like the
    following::

       stage = [
          { <stage 1 settings> }
          { <stage 2 settings> }
          { <stage 3 settings> }
          ...
       ]

    Each stage's name can be accessed in this way: raw.stage[1].__NAME__,
    where 'raw' is the returned 'sea.Map' object.
    """
    if not msj_content:
        with open(fname, "r") as msj_file:
            msj_content = msj_file.read()
    raw = sea.Map("stage = [" + msj_content + "]")
    # User might set a stage as "stagename = {...}" by mistake. Raises a
    # meaningful exception when this happens.
    for s in raw.stage:
        if isinstance(s, sea.Atom) and s.val == "=":
            raise SyntaxError(
                "Stage name cannot be followed by the assignment operator: '='")
    stg = list(range(len(raw.stage)))[::2]
    for i in stg:
        try:
            s = raw.stage[i + 1]
            name = raw.stage[i].val.lower()
            s.__NAME__ = name
        except IndexError:
            raise SyntaxError("stage %d is undefined" % (i + 1))
    stg.reverse()
    for i in stg:
        del raw.stage[i]
    return raw
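# A small usage sketch (the msj text below is made up): parsing two stages from
# a string rather than a file.
#
#     raw = msj2sea(None, msj_content='task { task = "desmond:auto" }\n'
#                                     'simulate { time = 100.0 }')
#     raw.stage[0].__NAME__   # -> "task"
#     raw.stage[1].__NAME__   # -> "simulate"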
[docs]def msj2sea_full(fname, msj_content=None, pset=""):
    raw = msj2sea(fname, msj_content)
    for i, e in enumerate(raw.stage):
        try:
            stage_cls = StageBase.stage_cls[e.__NAME__]
        except KeyError:
            raise ParseError("Unrecognized stage name: %s\n" % e.__NAME__)
        param = copy.deepcopy(stage_cls.PARAM.DATA)
        param.update(e, tag="setbyuser")
        param.__NAME__ = e.__NAME__
        param.__CLS__ = stage_cls
        raw.stage[i] = param
    if pset:
        raw.stage.insert(0, sea.Atom("dummy"))
        pset = pset.split(chr(30))
        for e in pset:
            i = e.find("=")
            if i <= 0:
                raise ParseError("Syntax error in setting: %s" % e)
            try:
                key = e[:i].strip()
                value = e[i + 1:].strip()
            except IndexError:
                raise ParseError("Syntax error in setting: %s" % e)
            if key == "" or value == "":
                raise ParseError("Syntax error in setting: %s" % e)
            raw.set_value(key,
                          sea.Map("value = %s" % value).value.val,
                          tag="setbyuser")
        del raw.stage[0]
    return raw
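# Sketch of the `pset` override syntax handled above (the keys and file name
# are illustrative): individual "key = value" settings are joined by the ASCII
# record separator, chr(30), and applied on top of the parsed stage parameters.
# A dummy element is temporarily prepended, apparently so that "stage[1]"
# addresses the first stage in the msj, matching the 1-based indexing used by
# parse_msj below.
#
#     pset = chr(30).join(['stage[1].time = 200.0',
#                          'stage[1].jobname = "equil"'])
#     raw = msj2sea_full("job.msj", pset=pset)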
[docs]def parse_msj(fname, msj_content=None, pset=""):
    """
    sea.update_macro_dict must be called prior to calling this function.
    """
    try:
        global PARAM
        PARAM = msj2sea_full(fname, msj_content, pset)
        PARAM.stage.insert(0, sea.Atom("dummy"))
    except Exception as e:
        raise ParseError(str(e))

    print_debug("All settings of this multisim job...")
    print_debug(PARAM)
    print_debug("All settings of this multisim job... End")

    # Constructs stage objects and their parameters.
    stg = []
    error = ""
    for i, e in enumerate(PARAM.stage[1:], start=1):
        s = e.__CLS__()  # Creates a stage instance.
        # handle backward-compatibility issues
        s.migrate_param(e)
        s.param = e
        # FIXME: How to deal with exceptions raised by the parsing and checking
        # functions?
        ev = s.check_param()
        if ev.err != "":
            error += "Value error(s) for stage[%d]:\n%s\n" % (i, ev.err)
        if ev.unchecked_map:
            error += "Unrecognized parameters for stage[%d]: %s\n\n" % (
                i, ev.unchecked_map)
        stg.append(s)
    if error:
        raise ParseError(error)
    return stg
[docs]def write_msj(stage_list, fname=None, to_str=True):
    """
    Given a list of stages, writes out a .msj file of the name 'fname'.

    If 'to_str' is True, a string will be returned. The returned string
    contains the contents of the .msj file.

    If 'to_str' is False and no file name is provided, then this function
    does nothing.
    """
    if fname is None and to_str is False:
        return

    s = ""
    for stage in stage_list:
        s += stage.NAME + " {\n"
        s += stage.param.__str__(" ", tag="setbyuser")
        s += "}\n\n"

    if fname is not None:
        with open(fname, "w") as fh:
            print(s, file=fh)

    if to_str:
        return s
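# Sketch of a typical round trip (file names are illustrative; per the
# parse_msj docstring, sea.update_macro_dict must have been called first):
#
#     stage_list = parse_msj("job.msj")
#     text = write_msj(stage_list)                           # msj contents
#     write_msj(stage_list, fname="out.msj", to_str=False)   # write a file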
[docs]def write_sea2msj(stage_list, fname=None, to_str=True):
    if fname is None and to_str is False:
        return

    s = ""
    for stage in stage_list:
        name = stage.__NAME__
        s += name + " {\n"
        s += stage.__str__(" ", tag="setbyuser")
        s += "}\n\n"

    if fname is not None:
        with open(fname, "w") as fh:
            print(s, file=fh)

    if to_str:
        return s
def _collect_inputfile_from_file_list(list_, fnames):
    for v in list_:
        if isinstance(v, sea.Atom) and isinstance(v.val, str) and v.val != "":
            fnames.append(v.val)
        elif isinstance(v, sea.Map):
            _collect_inputfile_from_file_map(v, fnames)
        elif isinstance(v, sea.List):
            _collect_inputfile_from_file_list(v, fnames)
    return fnames


def _collect_inputfile_from_file_map(map, fnames):
    for k, v in map.key_value():
        if isinstance(v, sea.Atom) and isinstance(v.val, str) and v.val != "":
            fnames.append(v.val)
        elif isinstance(v, sea.Map):
            _collect_inputfile_from_file_map(v, fnames)
        elif isinstance(v, sea.List):
            _collect_inputfile_from_file_list(v, fnames)
    return fnames


def _collect_inputfile_from_list(list_, fnames):
    for v in list_:
        if isinstance(v, sea.Map):
            _collect_inputfile_from_map(v, fnames)
        elif isinstance(v, sea.List):
            _collect_inputfile_from_list(v, fnames)
    return fnames


def _collect_inputfile_from_map(map, fnames):
    for k, v in map.key_value():
        if (isinstance(v, sea.Atom) and k.endswith("_file") and
                isinstance(v.val, str) and v.val != ""):
            fnames.append(v.val)
        elif isinstance(v, sea.Map):
            if k.endswith("_file"):
                _collect_inputfile_from_file_map(v, fnames)
            else:
                _collect_inputfile_from_map(v, fnames)
        elif isinstance(v, sea.List):
            if k.endswith("_file"):
                _collect_inputfile_from_file_list(v, fnames)
            else:
                _collect_inputfile_from_list(v, fnames)
    return fnames
[docs]def collect_inputfile(stage_list):
    """
    Returns a list of file names.
    """
    fnames = []
    for stage in stage_list:
        if not stage.param.should_skip.val:
            try:
                fnames.extend(stage.collect_inputfile())
            except AttributeError:
                _collect_inputfile_from_map(stage.param, fnames)
    return fnames
[docs]class AslValidator(object):
    CTSTR = """hydrogen


  1  0  0  0  1  0            999 V2000
   -1.6976    2.1561    0.0000 C   0  0  0  0  0  0
M  END
$$$$
"""
    CT = None

[docs]    def __init__(self):
        self.invalid_asl_expr = []

[docs]    def is_valid(self, asl):
        if AslValidator.CT is None:
            import schrodinger.structure as structure

            AslValidator.CT = next(
                structure.StructureReader.fromString(AslValidator.CTSTR,
                                                     format="sd"))
        import schrodinger.structutils.analyze as analyze

        try:
            analyze.evaluate_asl(AslValidator.CT, asl)
        except mm.MmException:
            return False
        return True

[docs]    def validate(self, a):
        if isinstance(a, sea.Atom):
            v = a.val
            if (isinstance(v, str) and v[:4].lower() == "asl:" and
                    not self.is_valid(v[4:])):
                self.invalid_asl_expr.append(v)
        elif isinstance(a, sea.Sea):
            a.apply(lambda x: self.validate(x))
[docs]def validate_asl_expr(stage_list):
    """
    Validates all ASL expressions that start with the "asl:" prefix.
    """
    validator = AslValidator()
    for stage in stage_list:
        if not stage.param.should_skip.val:
            stage.param.apply(lambda x: validator.validate(x))
    return validator.invalid_asl_expr
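# Usage sketch (assuming `stage_list` came from parse_msj; how the caller
# reports the failures is an assumption): every parameter value of the form
# "asl:<expression>" is evaluated against a one-atom scratch structure, and the
# expressions that fail to evaluate are returned.
#
#     bad = validate_asl_expr(stage_list)
#     if bad:
#         raise ParseError("Invalid ASL expression(s): %s" % ", ".join(bad))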
# - Registered functions should share this prototype: foo( stage, PARAM, arg ),
#   and should return a boolean value, where PARAM is a sea.Map object in the
#   global scope of this module.
_operator = {}
[docs]def reg_checking(name, func):
    _operator[name] = func
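# Hypothetical registration sketch following the prototype described above
# (the "has_time" name and its logic are made up for illustration):
#
#     def has_time(stage, PARAM, arg):
#         return "time" in stage.param
#
#     reg_checking("has_time", has_time)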