Source code for schrodinger.application.desmond.starter.generator.common

import copy
import os
from pathlib import Path
from typing import TYPE_CHECKING
from typing import List
from typing import Optional

from schrodinger.application.desmond import cmj
from schrodinger.application.desmond import launch_utils
from schrodinger.application.desmond import stage
from schrodinger.application.desmond import util
from schrodinger.application.desmond.constants import FEP_TYPES
from schrodinger.application.desmond.constants import UiMode
from schrodinger.application.desmond.constants import SIMULATION_PROTOCOL
from schrodinger.application.desmond import multisim
from schrodinger.application.desmond.starter.ui import cmdline
from schrodinger.application.desmond.starter.ui.cmdline import FepArgs
from schrodinger.utils import sea

if TYPE_CHECKING:
    from schrodinger.application.scisol.packages.fep import graph  # noqa: F401


def is_fmp(input_fname: str) -> bool:
    """
    Tell whether `input_fname` names an fmp file.

    :param input_fname: Name (or path) of the input file.
    :return: True if the file extension is `.fmp` or `.pkl`, compared
        case-insensitively. False otherwise.
    """
    extension = Path(input_fname).suffix
    return extension.lower() in ('.fmp', '.pkl')
def find_fmpdb_file(args: cmdline.BaseArgs) -> Optional[str]:
    """
    Try to find the name of the .fmpdb file if one is needed.

    A warning is printed, and `None` is returned, when a .fmpdb file is
    expected but its name cannot be determined or the file does not exist.

    :param args: Command line arguments. Only `inp_file` and `JOBNAME` are
        read here.
    :return: The .fmpdb file name if the file exists, otherwise `None`.
    """
    inp_file = args.inp_file
    if inp_file is None:
        # Both restarting and extending (running simulations for longer) the
        # previous job won't set the inp_file argument, then we require the
        # .fmpdb file to be in the CWD and be named after "<jobname>_out.fmpdb".
        # If the .fmpdb file is not found, warn.
        # An empty job name is treated like a missing one, so that the
        # "cannot figure out" warning is issued instead of a "not found"
        # warning with an empty file name.
        filename = f"{args.JOBNAME}_out.fmpdb" if args.JOBNAME else None
    else:
        # Two cases:
        # 1. This is a graph-expansion job: inp_file is a .fmp file.
        #    Find the .fmpdb file from the .fmp file.
        # 2. This is a from-scratch new job: inp_file is a .mae file or a .fmp
        #    file without an associated .fmpdb file.
        #    No need for a .fmpdb file and warnings.
        filename = None
        if is_fmp(inp_file):
            # Case 1
            from schrodinger.application.scisol.packages.fep import graph  # noqa: F811
            g = graph.Graph.deserialize(inp_file)
            # Normalize any falsy result (no fmpdb, empty name) to None.
            filename = (g.fmpdb and g.fmpdb.filename) or None
        if filename is None:
            # Case 2
            return None

    if filename is None:
        # Tried, but cannot figure out the file name.
        print("WARNING: Cannot figure out the .fmpdb file name.")
        return None
    if not os.path.isfile(filename):
        print("WARNING: .fmpdb file not found: %s" % filename)
        return None
    return filename
def prepare_files_and_command_for_restart(args: cmdline.BaseArgs) -> List[str]:
    """
    Return a command for launching the restart multisim job.

    :param args: Command line arguments.
    :return: The command for launching the restart job.
    :raise RestartException: If no multisim stage could be found in the
        checkpoint.
    """
    cpt_fname, rst_stage_idx = launch_utils.get_checkpoint_file_and_restart_number(
        args.checkpoint)
    # `rst_whole` is only set when a positive stage number came with the
    # checkpoint argument; it must be computed before the fallback below
    # replaces a missing stage number.
    rst_whole = rst_stage_idx is not None and rst_stage_idx > 0
    engine = launch_utils.read_checkpoint_file(cpt_fname)
    if not rst_stage_idx:
        rst_stage_idx = launch_utils.get_restart_stage_from_engine(engine)

    multisim_stage_numbers = launch_utils.get_multisim_stage_numbers(engine)
    if not multisim_stage_numbers:
        raise RestartException("ERROR: multisim stage not found.")

    launch_utils.validate_restart_stage(engine, rst_stage_idx)
    stage_data_fnames = launch_utils.prepare_multisim_files_for_restart(
        engine,
        multisim_stage_numbers,
        rst_stage_idx,
        rst_whole,
        skip_traj=args.skip_traj)

    # The host is passed through unchanged; the number of jobs is given
    # separately via `maxjob`.
    cmd = launch_utils.prepare_command_for_restart(engine,
                                                   stage_data_fnames,
                                                   args.HOST,
                                                   args.SUBHOST,
                                                   cpt_fname,
                                                   maxjob=args.ppj,
                                                   jobname=args.JOBNAME,
                                                   msj=args.msj,
                                                   rst_stage_idx=rst_stage_idx,
                                                   rst_whole=rst_whole)
    forcefield = None
    cmd += launch_utils.additional_command_arguments(
        stage_data_fnames, args.RETRIES, args.WAIT, args.LOCAL, args.DEBUG,
        args.TMPDIR, forcefield, args.OPLSDIR, args.NICE, args.SAVE)
    return cmd
def prepare_files_and_command_for_fep_restart_extend(
        args: FepArgs,
        edges: List[str],
        launcher_stage_name: str = stage.FepLauncher.NAME
) -> (List[str], List[str]):
    """
    Prepare the files and the command for restarting or extending an FEP job.

    In `UiMode.EXTEND` mode, `edges` are handed to the launcher stage's
    `restart_edges`, the modified checkpoint is written to
    `extend_<checkpoint basename>`, and the main msj is written to
    `<JOBNAME>.extend.msj` together with the per-leg extend msj files.
    In `UiMode.RESTART` mode, the main msj is written to
    `<JOBNAME>.restart.msj`. In both modes `args.msj` is updated to the name
    of the main msj file that was written.

    :param args: Command line arguments. `args.msj` is modified.
    :param edges: Edges passed to the launcher stage's `restart_edges`
        (EXTEND mode only).
    :param launcher_stage_name: Name of the launcher stage to look up in the
        checkpoint and in the main msj.
    :return: A tuple of the launch command and the stage data file names.
    :raise RestartException: If no multisim stage could be found in the
        checkpoint.
    """
    stage_data_fnames = []
    cpt_fname, rst_stage_idx = launch_utils.get_checkpoint_file_and_restart_number(
        args.checkpoint)
    engine = launch_utils.read_checkpoint_file(cpt_fname)
    rst_whole = False
    multisim_stage_numbers = launch_utils.get_multisim_stage_numbers(engine)
    if not multisim_stage_numbers:
        raise RestartException("ERROR: multisim stage not found.")

    if args.mode == UiMode.EXTEND:
        from schrodinger.application.scisol.packages.fep import utils
        from schrodinger.application.scisol.packages.fep import graph  # noqa: F811
        rst_stage_idx = multisim_stage_numbers[-1]
        g = graph.Graph.deserialize(f"{engine.jobname}_out.fmp")
        if g.fep_type in [FEP_TYPES.ABSOLUTE_BINDING, FEP_TYPES.SOLUBILITY]:
            sim_protocols = {
                utils.get_ligand_node(e).short_id: e.simulation_protocol
                for e in g.edges_iter()
            }
        else:
            sim_protocols = {
                "_".join(e.short_id): e.simulation_protocol
                for e in g.edges_iter()
            }

        # Modifies the checkpoint file.
        fep_launcher_stage = launch_utils.find_stage(engine.stage,
                                                     launcher_stage_name)
        # Temporarily swap the global engine; always restore it, even if
        # `restart_edges` raises, so the module-level state is not left
        # pointing at the engine read from the checkpoint.
        current = cmj.ENGINE
        cmj.ENGINE = engine
        try:
            fep_launcher_stage.restart_edges(edges, sim_protocols=sim_protocols)
        finally:
            cmj.ENGINE = current
        cpt_fname = "extend_%s" % os.path.basename(args.checkpoint)
        engine.write_checkpoint(cpt_fname)

        # Modifies the msj file.
        main_msj = _update_input_graph_file_param(args, engine)
        args.msj = f"{args.JOBNAME}.extend.msj"
        main_msj.write(args.msj)
        fep_launcher_dispatch = main_msj.get(f"{launcher_stage_name}.dispatch")
        # `_write_extend_msjs` updates the dispatch entries in place, which is
        # why the dispatch is put back into the msj afterwards.
        extend_stage_nums = {
            protocol_name: _write_extend_msjs(
                args, fep_launcher_dispatch[protocol_name], protocol_name)
            for protocol_name, _ in fep_launcher_dispatch.items()
        }
        main_msj.put(f"{launcher_stage_name}.dispatch",
                     sea.Map(fep_launcher_dispatch))
        main_msj.put(f"{launcher_stage_name}.restart",
                     sea.Map(extend_stage_nums))
        main_msj.write(args.msj)
    elif args.mode == UiMode.RESTART:
        # Must be computed before the fallback below replaces a missing
        # stage number.
        rst_whole = rst_stage_idx is not None and rst_stage_idx > 0
        if not rst_stage_idx:
            rst_stage_idx = launch_utils.get_restart_stage_from_engine(engine)
        launch_utils.validate_restart_stage(engine, rst_stage_idx)
        main_msj = _update_input_graph_file_param(args, engine)
        args.msj = f"{args.JOBNAME}.restart.msj"
        main_msj.write(args.msj)

    # NOTE(review): the stage data collection below is assumed to apply to
    # both the EXTEND and the RESTART mode -- confirm.
    stage_data_fnames.extend(
        launch_utils.prepare_multisim_files_for_restart(
            engine,
            multisim_stage_numbers,
            rst_stage_idx,
            rst_whole,
            skip_traj=args.skip_traj))
    stage_data_fnames.extend(
        _prepare_mapper_stages_for_restart(engine, rst_stage_idx))
    # Deduplicate names, keeping the first occurrence so that the order is
    # deterministic.
    stage_data_fnames = list(dict.fromkeys(stage_data_fnames))

    cmd = launch_utils.prepare_command_for_restart(engine,
                                                   stage_data_fnames,
                                                   args.HOST,
                                                   args.SUBHOST,
                                                   cpt_fname,
                                                   maxjob=args.maxjob,
                                                   jobname=args.JOBNAME,
                                                   msj=args.msj,
                                                   rst_stage_idx=rst_stage_idx,
                                                   rst_whole=rst_whole)
    return cmd, stage_data_fnames
def _prepare_mapper_stages_for_restart(engine: cmj.Engine,
                                       rst_stage_idx: int) -> List[str]:
    """
    If the FepMapperReport or FepMapperCleanup stage is present after the
    restart stage, include the FepMapper stage data when restarting the job.

    :param engine: Represents the current job state.
    :param rst_stage_idx: The restart stage index.
    :return: List of filenames to be used for restarting the job. Empty if
        no mapper stage data is needed or no mapper stage is found.
    """
    stage_data_fnames = []
    dependent_stage_names = (stage.FepMapperReport.NAME,
                             stage.FepMapperCleanup.NAME)
    for stg in engine.stage[rst_stage_idx:]:
        if stg.NAME not in dependent_stage_names:
            continue
        mapper_stage_number = (
            launch_utils.find_stage_number(engine.stage, stage.FepMapper.NAME)
            or launch_utils.find_stage_number(engine.stage,
                                              stage.ProteinFepMapper.NAME) or
            launch_utils.find_stage_number(engine.stage,
                                           stage.CovalentFepMapper.NAME))
        if not mapper_stage_number:
            # None of the mapper stages is present: there is no stage data
            # to add, and subtracting from a falsy result would either fail
            # or produce a bogus file name.
            break
        # find_stage_number returns a 1-based index, and engine.stage's
        # first element is a primer stage which should be uncounted
        mapper_stage_number -= 1
        if mapper_stage_number < rst_stage_idx:
            stage_data_fnames.append(
                f"{engine.jobname}_{mapper_stage_number}-out.tgz")
        break
    return stage_data_fnames


def _update_input_graph_file_param(args: FepArgs,
                                   engine: cmj.Engine) -> multisim.Msj:
    """
    Add input_graph_file to main msj for certain stages.

    The main msj is parsed from the `args.msj` file if that is set, otherwise
    from the msj content stored in `engine`.

    :param args: Command line arguments.
    :param engine: Represents the current job state.
    :return: The parsed (and possibly updated) main msj.
    """
    if args.msj:
        main_msj = multisim.parse(args.msj)
    else:
        main_msj = multisim.parse(string=engine.msj_content)

    # Use the previous out.fmp file as input graph
    STAGES = [
        "vacuum_report", "fep_mapper_report", "fep_mapper_cleanup",
        "fep_absolute_binding_analysis", "fep_solubility_analysis"
    ]
    out_fmp_name = f"{engine.jobname}_out.fmp"
    # The msj is left untouched if the previous out.fmp file is not present.
    if os.path.isfile(out_fmp_name):
        for stage_name in STAGES:
            for s in main_msj.find(stage_name):
                s.put("input_graph_file", out_fmp_name)
    return main_msj


def _write_extend_msjs(args: FepArgs,
                       job: List[List[str]],
                       protocol_name="default") -> List[int]:
    """
    Write the .extend.msj and return the stage number to restart from for
    each msj.

    :param args: Command line arguments.
    :param job: List of commands, one per leg. It is modified in place by
        `_write_extend_msj`.
    :param protocol_name: Name of the simulation protocol of `job`.
    :return: The extend stage numbers of the legs that are extended.
    """
    extend_stage_nums = []
    # Order of legs has to be same as order of legs in dispatch
    # avoid iterating over list while mutating it
    for job_args in copy.deepcopy(job):
        jobname = job_args[job_args.index("-JOBNAME") + 1]
        leg_type = util.get_leg_type_from_jobname(jobname)
        leg_name = util.get_leg_name_from_jobname(jobname)
        extend_stage_num = _write_extend_msj(args, job, protocol_name,
                                             leg_type, leg_name)
        if extend_stage_num:
            extend_stage_nums.append(extend_stage_num)
    return extend_stage_nums


def _write_extend_msj(args: FepArgs, job: List[List[str]], protocol_name: str,
                      leg_type: str, leg_name: str) -> Optional[int]:
    """
    1) Modify the dispatch command for the given protocol+leg in the main msj
       to point to the new extend msjs.
    2) Modify the added time of the extend stage in the subjob msj.
    3) Return the extend stage number.

    This will only modify the msj if the leg existed in the original msj (for
    example, it may skip 'vacuum').

    :param args: Command line arguments.
    :param job: List of commands, one per leg. Modified in place.
    :param protocol_name: Name of the simulation protocol of `job`.
    :param leg_type: Leg type, used to look up the simulation time and to
        name the new msj file.
    :param leg_name: Leg name, used to find the leg's command in `job`.
    :return: The extend stage number, or `None` if the leg is not extended.
    :raise RestartException: If the subjob msj file of the leg is not found.
    """
    sim_time = args.get_time_for_leg(leg_type)
    if sim_time is None:
        return None

    leg_idx = _find_leg_idx(leg_name, job)
    if leg_idx is None:
        return None

    if sim_time < 1.0:
        # skip extending a given leg if simulation time is < 1 ps.
        job.pop(leg_idx)
        return None

    if protocol_name == SIMULATION_PROTOCOL.DEFAULT:
        new_msj_fname = f"{args.JOBNAME}_{leg_type}.extend.msj"
    elif protocol_name in [
            SIMULATION_PROTOCOL.CHARGED, SIMULATION_PROTOCOL.FORMALCHARGED
    ]:
        new_msj_fname = f"{args.JOBNAME}_{leg_type}_chg.extend.msj"
    else:
        new_msj_fname = f"{args.JOBNAME}_{leg_type}_{protocol_name}.extend.msj"

    # Remember the original subjob msj and point the command to the new one.
    leg_cmd = job[leg_idx]
    msj_fname = None
    for i, e in enumerate(leg_cmd):
        if e == "-m":
            msj_fname = leg_cmd[i + 1]
            leg_cmd[i + 1] = new_msj_fname

    if not (msj_fname and os.path.isfile(msj_fname)):
        raise RestartException(f"ERROR: File not found: '{msj_fname}'")

    subjob_msj = multisim.parse(msj_fname)
    subjob_msj.put("desmond_extend.added_time", sim_time)
    subjob_msj.write(new_msj_fname)
    return subjob_msj.find_stages("desmond_extend")[0].STAGE_INDEX


def _find_leg_idx(leg: str, job: List) -> Optional[int]:
    """
    Find the index of the command in `job` that belongs to the given leg.

    :param leg: Leg name to look for.
    :param job: List of commands.
    :return: Index of the first matching command, or `None` if there is none.
    """
    for idx, cmd in enumerate(job):
        jobname = cmd[cmd.index("-JOBNAME") + 1]
        if util.get_leg_name_from_jobname(jobname) == leg:
            return idx
    return None
class RestartException(Exception):
    """
    Raised when the files or the command for restarting or extending a job
    cannot be prepared.
    """