Source code for schrodinger.application.desmond.mxmd.mxmd_cleanup

"""
A post-simulation clean-up script for the Mixed Solvent (MxMD) workflow. This
script extracts the occupancy data for all co-solvent subjobs and combines them
into occupancy maps. It then clusters these maps and identifies Hotspots from
them.

As output, a Maestro Project file (.prjzip) and a 'results' directory are
written. The directory contains CNS maps for all co-solvent probes and a Maestro
structure of the last snapshots for all co-solvent subjobs. The command should
be run from the base directory of the mixed solvent job.
"""

import argparse
import os
import shutil
import tarfile
from collections import defaultdict
from io import BytesIO
from itertools import combinations
from itertools import product
from pathlib import Path
from typing import Callable
from typing import List
from typing import Optional
from typing import Set
from typing import Tuple
from typing import Union

import networkx as nx
import numpy as np
from scipy.cluster.hierarchy import fcluster
from scipy.cluster.hierarchy import linkage
from scipy.spatial.distance import cdist

from schrodinger import project
from schrodinger import structure
from schrodinger.application.desmond import constants
from schrodinger.application.desmond import struc
from schrodinger.application.desmond.cmj import Engine
from schrodinger.application.desmond.cmj import JobStatus
from schrodinger.application.desmond.cns_io import write_cns
from schrodinger.application.desmond.mxmd.mxmd_system_builder import \
    BUILTIN_PROBE_NAMES
from schrodinger.application.desmond.stage.app.mxmd import \
    normalize_probe_occupancy_grid
from schrodinger.infra import mm
from schrodinger.infra import mmproj
from schrodinger.infra import mmsurf
from schrodinger.utils import log

# Module-level logger used for the script's progress messages.
logger = log.get_output_logger("mxmd_cleanup.py")

# Palette of RGB tuples (0-255 per channel) from which probe isosurface colors
# are drawn.
colors = [(255, 166, 115), (1, 111, 244), (208, 204, 28), (147, 116, 255),
          (57, 190, 62), (152, 50, 185), (141, 217, 101), (217, 61, 189),
          (173, 171, 0), (17, 78, 171), (197, 205, 91), (2, 125, 217),
          (180, 0, 6), (79, 220, 182), (247, 31, 130), (0, 150, 81),
          (160, 3, 103), (0, 105, 39), (255, 88, 123), (1, 160, 144),
          (163, 11, 77), (62, 91, 25), (220, 169, 255), (134, 119, 0),
          (133, 194, 255), (184, 103, 0), (90, 111, 162), (100, 81, 25),
          (104, 68, 127), (131, 56, 98)]

HOTSPOT_SURF_COLOR = (255, 255, 255)  # white
# Default color of every builtin probe: palette entries are handed out in
# BUILTIN_PROBE_NAMES order and wrap around when the palette is exhausted.
BUILTIN_COLOR_DICT = {
    p: colors[i % len(colors)] for i, p in enumerate(BUILTIN_PROBE_NAMES)
}

# to maintain the same color scheme previously used
BUILTIN_COLOR_DICT.update({
    "acetone": (100, 100, 225),  # light blue
    "acetonitrile": (234, 130, 50),  # orange
    "imidazole": (234, 173, 234),  # plum
    "isopropanol": (30, 30, 225),  # blue
    "nmethylacetamide": (165, 42, 42),  # brown
    "pyrimidine": (225, 30, 225)  # purple
})


def _get_color_for_probe(probe_name: str,
                         probe_idx: int) -> Tuple[int, int, int]:
    """
    Pick the isosurface color for a probe.

    Builtin probes use their entry in `BUILTIN_COLOR_DICT`. Any other probe
    gets a palette color chosen from its position in the simulation's probe
    list, offset by the number of builtin entries and wrapped around the
    palette length.

    :param probe_name: The name of the probe (e.g. acetonitrile)
    :param probe_idx: The index of the probe in the list of probes used for the
        simulation. Only consulted when `probe_name` has no builtin color.
    :return: an rgb tuple, with each value ranging from 0-255
    """
    builtin_color = BUILTIN_COLOR_DICT.get(probe_name)
    if builtin_color:
        return builtin_color
    palette_slot = (len(BUILTIN_COLOR_DICT) + probe_idx) % len(colors)
    return colors[palette_slot]


def set_isosurface(pt: "project.Project", cns_file: str, row_index: int,
                   cutoff: float, color: tuple, surf_name: str,
                   surf_comment: str):
    """
    Attach the volume stored in a CNS file, plus an isosurface generated from
    it, to one entry of a project table.

    :param pt: Project being written; its `handle` is passed to mmproj.
    :param cns_file: CNS file to read the volume from.
    :param row_index: Index of the project entry that receives the volume and
        the surface.
    :param cutoff: Isovalue passed to the isosurface constructor.
    :param color: Color passed to `mmsurf_set_rgb_color`.
    :param surf_name: Name given to the surface.
    :param surf_comment: Comment stored on the surface.
    """
    # Only the volume handle and the grid center are needed; the map type and
    # its comment returned by the reader are ignored.
    volume, center, _map_type, _map_type_comment = \
        mmsurf.mmvisio_read_cns(cns_file)
    mmproj.mmproj_index_entry_add_volume(pt.handle, row_index, volume)

    mmsurf.mmsurf_set_gen_isolines(True)
    surface = mmsurf.mmsurf_new_isosurface(volume, cutoff, False, 0,
                                           center[0], center[1], center[2], 0)

    mmsurf.mmsurf_set_name(surface, surf_name)
    mmsurf.mmsurf_set_rgb_color(surface, color)

    # The surface refers back to its source volume by name.
    volume_name = mmsurf.mmvol_get_name(volume)
    mmsurf.mmsurf_set_volume_name(surface, volume_name)
    mmsurf.mmsurf_set_scheme_volume_name(surface, volume_name)

    mmsurf.mmsurf_set_style(surface, mmsurf.MMSURF_STYLE_SOLID)
    mmsurf.mmsurf_set_transparency(surface, 50)
    mmsurf.mmsurf_set_visibility(surface, True)
    mmsurf.mmsurf_set_comment(surface, surf_comment)
    mmsurf.mmsurf_set_surface_type(surface, 'Desmond MxMD')

    mmproj.mmproj_index_entry_add_surface(pt.handle, row_index, surface, None)
# Triple of floats; used below to annotate per-axis grid spacing, box size and
# grid center values.
Float3 = Tuple[float, float, float]
class Spot:
    """
    Occupancy map for each probe is clustered and discretized into separate
    occupancy clusters called 'Spots'. Each Spot object contains the grid's
    coordinates, their occupancy values and the probe CTs.
    """

    def __init__(self, *, grid_points: "Set[Float3]", grid: np.ndarray,
                 probe_name: str, grid_spacing: "Float3",
                 probe_mols: "Set[structure._Molecule]"):
        """
        :param grid_points: Grid points that belong to this spot. They are
            used as keys into `grid` by `Hotspot.add_spot`.
        :param grid: Occupancy grid holding this spot's values.
        :param probe_name: Name of the probe this spot was derived from.
        :param grid_spacing: Grid spacing along each axis.
        :param probe_mols: Probe molecules associated with this spot.
        """
        self.grid = grid
        self.grid_points = grid_points
        self._grid_spacing = grid_spacing
        self._probe_name = probe_name
        self._probe_mols = probe_mols

    def __str__(self):
        return f"{self.probe_name}_{self.volume} ų"

    @property
    def volume(self):
        """
        Number of grid points times the volume of a single grid cell.
        """
        # np.prod replaces the np.product alias, which NumPy 2.0 removed.
        return len(self.grid_points) * np.prod(self._grid_spacing)

    @property
    def probe_name(self):
        return self._probe_name

    @property
    def probe_mols(self):
        return self._probe_mols
class Hotspot:
    """
    A `Hotspot` refers to a collection of grid points, occupied by two or more
    `Spots`.
    """

    def __init__(self, box_size: "Float3", center: "Float3",
                 grid_spacing: "Float3"):
        """
        :param box_size: Box size passed on when the grid is written to CNS.
        :param center: Grid center passed on when the grid is written to CNS.
        :param grid_spacing: Grid spacing along each axis.
        """
        self._spots: "List[Spot]" = []
        self._grid_points: "Set[Float3]" = set()
        self._grid_spacing = grid_spacing
        self._box_size = box_size
        self._center = center
        # grid point -> occupancy values contributed by each added spot
        self._points_sigmas = defaultdict(list)
        # probe name -> probe molecules collected from the added spots
        self._probe_cts = defaultdict(set)

    def __str__(self):
        return (f"Hotspot: contains {len(self._spots)} spots, "
                f"totaling {self.volume} ų and with MxMD score of "
                f"{self.score(np.sum)}")

    def add_spot(self, spot: "Spot"):
        """
        Merge a spot's grid points, occupancy values and probe molecules into
        this hotspot. `Spot` must have the same grid_spacing.
        """
        self._spots.append(spot)
        self._grid_points = self._grid_points | spot.grid_points
        for gp in spot.grid_points:
            self._points_sigmas[gp].append(spot.grid[gp])
        self._probe_cts[spot.probe_name] |= spot.probe_mols

    @property
    def volume(self) -> float:
        """
        Number of distinct grid points times the volume of one grid cell.
        """
        # np.prod replaces the np.product alias, which NumPy 2.0 removed.
        return len(self._grid_points) * np.prod(self._grid_spacing)

    def score(self, func) -> int:
        """
        Sum the per-grid-point occupancies after reducing the values of
        overlapping spots at each point with `func`.

        :param func: How to treat the occupancy values of overlapping grid
            points; can be either `np.mean`, `np.max` or `np.sum` or `np.min`
        :return: The summed score, truncated to an integer.
        """
        return int(sum(func(v) for v in self._points_sigmas.values()))

    def get_probes_structure(self, hotspot_id=None,
                             log=False) -> "Optional[structure.Structure]":
        """
        Merge all probe molecules collected for this hotspot into one
        structure.

        :param hotspot_id: If given, used to build the structure title.
        :param log: Whether to report the number of molecules found.
        :return: The merged structure, or None when the hotspot has no probe
            molecules.
        """
        nmols = sum(len(mols) for mols in self._probe_cts.values())
        if nmols == 0:
            if log:
                logger.info(' \\-> WARNING: Did not find any probe molecules '
                            'to fit the hotspot occupancy.')
            return None
        if log:
            logger.info(f' \\-> Found {nmols} probe molecules to fit the '
                        'hotspot occupancy.')
        # Extract the molecules from their respective structures and merge
        # them into a new CT.
        ct = structure.create_new_structure()
        for mols in self._probe_cts.values():
            for m in mols:
                ct = ct.merge(m.structure.extract(m.getAtomIndices()))
        if hotspot_id is not None:
            ct.title = f'Hotspot_{hotspot_id} probe molecules'
        return ct

    def write_mae(self, filename: str):
        """
        Write probe CTs that correspond to this hotspot. Nothing is written
        when the hotspot has no probe molecules.
        """
        ct = self.get_probes_structure(log=True)
        if ct is not None:
            ct.write(filename)

    def write_cns(self,
                  filename: str,
                  crop_size: Optional[float] = None,
                  func: Callable = np.sum):
        """
        Write hotspot grid to cns format.

        :param filename: Name of the CNS file to write.
        :param crop_size: to reduce the size of the CNS file that is written,
            crop the grid to the specified size. Ignored when it is not
            smaller than the box or when it would remove less than one grid
            cell per side.
        :param func: Reduces the values of overlapping spots at each grid
            point to the single value that is written.
        """
        hotspot_grid = np.zeros(self._spots[0].grid.shape)
        for k, v in self._points_sigmas.items():
            hotspot_grid[k] = func(v)
        box_size = self._box_size
        if crop_size and crop_size < self._box_size[0]:
            grid_size = crop_size / self._grid_spacing[0]
            # Plain int() replaces the removed np.int alias.
            crop_edge = int(
                np.floor((hotspot_grid.shape[0] - grid_size) / 2.))
            # A zero edge would turn the slice into [0:-0], i.e. an empty
            # grid, so only crop when at least one cell is removed per side.
            if crop_edge > 0:
                hotspot_grid = hotspot_grid[crop_edge:-crop_edge,
                                            crop_edge:-crop_edge,
                                            crop_edge:-crop_edge]
                box_size = [crop_size] * 3
        write_cns(hotspot_grid,
                  box_size,
                  self._grid_spacing,
                  filename,
                  center=self._center)
def read_checkpoint_file(chkpt_file: Union[Path, str]) -> "Optional[Engine]":
    """
    Read multisim checkpoint file.

    :param chkpt_file: Checkpoint file name.
    :return: The object returned by `Engine.deserialize` for the file.
    :raise: ValueError if the checkpoint file could not be read. The
        underlying I/O or EOF error is attached as the exception's cause.
    """
    try:
        with open(chkpt_file, 'rb') as fh:
            # NOTE(review): the EOFError handling suggests a pickle-style
            # format; confirm checkpoints only come from trusted jobs.
            return Engine.deserialize(fh)
    except EnvironmentError as err:
        raise ValueError(
            f"Could not read checkpoint file: {chkpt_file}. "
            "Please note that the script is only supposed to be run "
            "from the result directory.") from err
    except EOFError as err:
        raise ValueError(
            f"Could not read checkpoint file: {chkpt_file}. File is corrupt."
        ) from err
def get_stage(job: "Engine",
              stage_name: str) -> Tuple[Optional[int], Optional[str]]:
    """
    Return the index of stage with name stage_name.

    All stages are scanned, so when several stages share the name the index
    of the last one is returned.

    :param job: The checkpoint file object.
    :param stage_name: Value compared against each stage's `NAME`.
    :return: Index of stage_name or None if not found. Job directory (base
        name of the first job's directory of a matching stage that has jobs)
        or None if not found.
    """
    stage_idx = None
    job_dir = None
    for i, s in enumerate(job.stage):
        if s.NAME == stage_name:
            stage_idx = i
            if s.jobs:
                job_dir = os.path.basename(s.jobs[0].dir)
    return stage_idx, job_dir
def split_into_spots(probe_name: str, probe_data: np.array,
                     probe_mols: "List[structure._Molecule]",
                     grid_spacing: "Float3", grid_center: "Float3",
                     box_size: "Float3", sigma: float,
                     cluster_cutoff: float) -> "List[Spot]":
    """
    Given an occupancy grid for a single probe, cluster these points and
    create Spot objects from them.

    :param probe_name: Name stored on every created Spot.
    :param probe_data: Occupancy grid of the probe.
    :param probe_mols: Probe molecules; each one is attached to every spot it
        lies close to.
    :param grid_spacing: Grid spacing; the first component is used both to
        scale `cluster_cutoff` and as the molecule-to-grid distance cutoff.
    :param grid_center: Center of the grid in system coordinates.
    :param box_size: Box size; half of it is subtracted when converting grid
        indices to system coordinates.
    :param sigma: Only grid points with a value strictly above this are kept.
    :param cluster_cutoff: Distance threshold for the single-linkage
        clustering, in the same units as `grid_spacing`.
    :return: One Spot per cluster, in ascending cluster-label order; empty
        when no grid point exceeds `sigma`.
    """
    # grid indices above the sigma value
    indices = list(zip(*np.nonzero(probe_data > sigma)))
    if not indices:
        return []

    labels = [1]
    if len(indices) > 1:
        gt = cluster_cutoff / grid_spacing[0]
        z = linkage(indices, method='single', metric='euclidean')
        labels = fcluster(z, t=gt, criterion='distance')

    # Group the indices by cluster label in a single pass.
    clusters = defaultdict(set)
    for index, label in zip(indices, labels):
        clusters[label].add(index)

    # Loop invariants: conversion terms and per-molecule atom coordinates.
    spacing = np.asarray(grid_spacing)
    half_box = np.asarray(box_size) / 2
    center = np.asarray(grid_center)
    mol_coords = [
        (mol, np.array([a.xyz for a in mol.atom])) for mol in probe_mols
    ]

    spots = []
    for label in sorted(clusters):
        points = clusters[label]
        clust_data = np.zeros(probe_data.shape)
        for index in points:
            clust_data[index] += probe_data[index]

        # Find which grid points overlap with probe_cts molecules
        # convert grid points to system coordinates
        points_xyz = (np.asarray(list(points)) * spacing) - half_box + center
        probes_to_keep = set()
        for mol, atom_xyz in mol_coords:
            dist = cdist(atom_xyz, points_xyz)
            # use grid-spacing values for distance cutoff
            if np.min(dist) <= grid_spacing[0]:
                probes_to_keep.add(mol)

        spots.append(
            Spot(grid_points=points,
                 grid=clust_data,
                 probe_name=probe_name,
                 grid_spacing=grid_spacing,
                 probe_mols=probes_to_keep))
    return spots
class CleanUp:
    """
    Class for cleaning up mixed solvent subjobs.
    """

    def __init__(self,
                 chkpt_file: str,
                 sigma: float = 20.0,
                 cluster_cutoff: float = 3.0,
                 ligand_xyz: np.array = None):
        """
        :param chkpt_file: Main checkpoint file; it is read immediately and
            its directory is where subjob files are looked up.
        :param sigma: Occupancy threshold used for spot detection and as the
            isovalue of the probe isosurfaces.
        :param cluster_cutoff: Threshold used for clustering probe spots.
        :param ligand_xyz: Optional ligand coordinates; when given, hotspots
            close to them are reported in the log.
        """
        self.chkpt_file = Path(chkpt_file).absolute()
        self.exec_dir = Path(os.path.dirname(chkpt_file))
        self.sigma = sigma
        self.cluster_cutoff = cluster_cutoff
        self.main_job: Engine = read_checkpoint_file(self.chkpt_file)
        self.jobname = self.main_job.jobname
        self.subjob_names = self.get_subjob_names()
        # All three are keyed by probe name.
        self.grid_data = defaultdict(list)
        self.ct_data = defaultdict(list)
        self.ct_probe = defaultdict(list)
        self.ref_ct = None
        self.center: Float3 = None
        self.box_size: Float3 = None
        self.grid_spacing: Float3 = None
        self.ligand_xyz = ligand_xyz

    def run(self):
        """
        Run the cleanup workflow.
        """
        logger.info("Start cleanup...")
        self.create_results_directory()
        self.read_archive_data()
        self.set_ref_ct()
        self.gen_normgrid_data()
        self.write_maestro_project()
        logger.info("Cleanup done.")

    def get_subjob_names(self) -> List[str]:
        """
        :return: A list of subjob names (tags of the successful jobs of the
            first mixed solvent setup stage); empty if there is no such stage.
        """
        from schrodinger.application.desmond import stage
        for s in self.main_job.stage:
            if s.NAME != stage.MixedSolventSetup.NAME:
                continue
            return [
                subjob.tag
                for subjob in s.filter_jobs(status=[JobStatus.SUCCESS])
            ]
        # No setup stage: return an empty list instead of an implicit None so
        # that callers can always iterate over the result.
        return []

    def create_results_directory(self):
        """
        Create directory in which all data results will be written to. An
        existing directory of the same name is removed first.
        """
        self.result_dir = Path(f'{self.jobname}_results')
        if os.path.exists(self.result_dir):
            shutil.rmtree(self.result_dir)
        os.mkdir(self.result_dir)

    def read_archive_data(self):
        """
        Read the occupancy grids and structures from the output archive of
        each subjob's analysis stage.
        """
        from schrodinger.application.desmond import stage
        jobname = self.main_job.jobname
        multisim_stage_num, _ = get_stage(self.main_job, stage.Multisim.NAME)
        for subjob_name in self.subjob_names:
            subjob_base = os.path.join(f'{jobname}_{multisim_stage_num}',
                                       f'{jobname}_{subjob_name}')
            subjob_chkp_fname = \
                self.exec_dir / f'{subjob_base}-multisim_checkpoint'
            if not os.path.exists(subjob_chkp_fname):
                logger.info(f"Skipping {subjob_chkp_fname} as it is not "
                            "included in the archive.")
                continue
            subjob_chkp = read_checkpoint_file(subjob_chkp_fname)
            subjob_analysis_stage_num, _ = get_stage(
                subjob_chkp, 'mixed_solvent_analysis')
            tar_fn = self.exec_dir / \
                f'{subjob_base}_{subjob_analysis_stage_num}-out.tgz'
            if not os.path.exists(tar_fn):
                continue
            self._read_subjob_archive(tar_fn)

    def _read_subjob_archive(self, tar_fn):
        """
        Load every occupancy grid in one subjob archive, together with its
        receptor structure and probe molecules, into `grid_data`, `ct_data`
        and `ct_probe`.
        """
        with tarfile.open(tar_fn) as fh:
            for name in fh.getnames():
                if not name.endswith('.raw'):
                    continue
                # read the .raw file from the archive
                raw_bytes = fh.extractfile(fh.getmember(name)).read()
                data = np.load(BytesIO(raw_bytes))

                # read protein mae file
                member = fh.getmember(name.replace('.raw', '.mae'))
                ct_str = fh.extractfile(member).read().decode()
                ct = structure.Structure(mm.mmct_ct_from_string(ct_str))
                probe_name = ct.property[constants.MXMD_COSOLVENT_PROBE]
                self.prepare_ct(ct, probe_name)
                self.ct_data[probe_name].append(ct)
                self.center = (ct.property[constants.MXMD_CENTER_X],
                               ct.property[constants.MXMD_CENTER_Y],
                               ct.property[constants.MXMD_CENTER_Z])
                self.box_size = tuple(
                    [ct.property[constants.MXMD_BOX_LENGTH]] * 3)
                self.grid_spacing = tuple(
                    [ct.property[constants.MXMD_GRID_SPACING]] * 3)

                # Store the grid data
                self.grid_data[probe_name].append(data)

                # Read cosolvent mae
                probes_name = name.replace('-out.raw', '-probes.mae')
                try:
                    member = fh.getmember(probes_name)
                except KeyError:
                    # to handle pre 20-3 jobs
                    logger.info(f"Not found: {probes_name}")
                    continue
                ct_str = fh.extractfile(member).read().decode()
                cosolvent_ct = structure.Structure(
                    mm.mmct_ct_from_string(ct_str))
                self.ct_probe[probe_name] += list(cosolvent_ct.molecule)

    def gen_normgrid_data(self):
        """
        Generate normalized occupancy data: average the grids of each probe
        and normalize the result.
        """
        self.normgrid_data = {}
        for probe, data in self.grid_data.items():
            grid = np.zeros(data[0].shape)
            for g in data:
                grid += g
            grid /= len(data)
            self.normgrid_data[probe] = normalize_probe_occupancy_grid(grid)

    def write_cns_mae_files(self) -> List[str]:
        """
        Write CNS and Maestro files for each probe type.

        :return: List of cns files for each probe
        """
        cns_files = []
        for probe in self.grid_data:
            logger.info(f"Aggregating grid data for {probe}.")
            cns_fn = str(self.result_dir / f'{probe}.cns')
            grid = self.normgrid_data[probe]
            write_cns(grid,
                      self.box_size,
                      self.grid_spacing,
                      cns_fn,
                      center=self.center)
            cns_files.append(cns_fn)
            mae_fn = self.result_dir / f'{probe}.mae'
            with structure.StructureWriter(str(mae_fn)) as writer:
                for ct in self.ct_data[probe]:
                    writer.append(ct)
            logger.info(f'\tCNS and {len(self.ct_data[probe])} snapshots in '
                        f'MAE file are written to {self.result_dir} directory.')
        return cns_files

    def write_hotspot_files(self) -> Tuple[List[str], List[str], List[str]]:
        """
        Write CNS files and return a list of cns and mae filenames with
        comments about hotspot details.

        Spots of different probes that share grid points are linked in a
        graph; each connected component becomes one hotspot.
        """
        cosolvent_spots = {
            p: split_into_spots(p,
                                normgrid,
                                probe_mols=self.ct_probe[p],
                                grid_spacing=self.grid_spacing,
                                grid_center=self.center,
                                box_size=self.box_size,
                                sigma=self.sigma,
                                cluster_cutoff=self.cluster_cutoff)
            for p, normgrid in self.normgrid_data.items()
        }
        g = nx.Graph()
        for pi, pj in combinations(self.normgrid_data.keys(), 2):
            for spot_i, spot_j in product(cosolvent_spots[pi],
                                          cosolvent_spots[pj]):
                # isdisjoint avoids building the intersection set
                if not spot_i.grid_points.isdisjoint(spot_j.grid_points):
                    logger.info(f"Found overlapping Spots: {spot_i} | {spot_j}")
                    g.add_edge(spot_i, spot_j)

        hotspots = []
        for c in nx.connected_components(g):
            sub_graph = g.subgraph(c)
            h = Hotspot(self.box_size, self.center, self.grid_spacing)
            for s in sub_graph.nodes():
                h.add_spot(s)
            hotspots.append(h)
        hotspots.sort(key=lambda x: x.score(func=np.sum), reverse=True)

        cns_files, mae_files, comments = [], [], []
        for ih, h in enumerate(hotspots):
            comment = f'{ih}: {h}'
            logger.info(comment)
            cns_fn = str(self.result_dir / f'hotspot_{ih}.cns')
            mae_fn = str(self.result_dir / f'hotspot_{ih}.mae')
            h.write_cns(cns_fn, crop_size=self._solute_size + 10)
            h.write_mae(mae_fn)
            cns_files.append(cns_fn)
            mae_files.append(mae_fn)
            comments.append(comment)
        self.hotspots = hotspots
        logger.info(
            f"Total Druggability Score: {sum(h.score(func=np.sum) for h in hotspots)}"
        )
        if self.ligand_xyz is not None:
            self.check_ligand_overlap(hotspots)
        return cns_files, mae_files, comments

    def check_ligand_overlap(self, hotspots: "List[Hotspot]"):
        """
        Log every hotspot that has a grid point within a distance of 1 of any
        ligand coordinate.
        """
        box_size = np.array(self.box_size)
        for ih, h in enumerate(hotspots):
            grid = np.array(list(h._grid_points))
            grid_spacing = np.array(h._grid_spacing)
            # grid indices -> system coordinates
            grid_xyz = (grid * grid_spacing[0]) - (box_size / 2.) + h._center
            closest = np.min(cdist(grid_xyz, self.ligand_xyz))
            if closest <= 1:
                logger.info(f'Ligand(s) overlaps with HOTSPOT {ih}, '
                            f'@ {closest:.3f}: {h}')

    def set_ref_ct(self):
        """
        Read one of the output structures and extract the original input
        coordinates.

        :raise: ValueError if no output structure has been read.
        """
        ct = next((cts[0] for cts in self.ct_data.values() if cts), None)
        if ct is None:
            raise ValueError(
                "No output structures were read from the subjob archives; "
                "cannot determine the reference structure.")
        ct = struc.get_reference_ct(ct)
        ref_pos = ct.getXYZ()
        # Largest extent of the reference coordinates along any axis; used
        # later to size the cropped hotspot grids.
        self._solute_size = np.max(
            np.max(ref_pos, axis=0) - np.min(ref_pos, axis=0))
        self.prepare_ct(ct)
        ct.title = self.jobname + '_input'
        self.ref_ct = ct
        self.ref_ct.write(
            os.path.join(self.result_dir, f'{self.jobname}-input.mae'))

    def write_maestro_project(self):
        """
        Write a maestro prj table containing a summary of the results.

        :raise: ValueError if the project could not be zipped.
        """
        pt_name = f'{self.jobname}.prj'
        if os.path.exists(pt_name):
            logger.info(f"Removing existing Maestro project {pt_name} to "
                        "create a new one.")
            shutil.rmtree(pt_name)
        mmproj.mmproj_initialize(mm.MMERR_DEFAULT_HANDLER)
        mmsurf.mmsurf_initialize(mm.MMERR_DEFAULT_HANDLER)
        mmsurf.mmvol_initialize(mm.MMERR_DEFAULT_HANDLER)
        mmsurf.mmvisio_initialize(mm.MMERR_DEFAULT_HANDLER)

        probe_cns_files = self.write_cns_mae_files()
        hotspot_cns_files, hotspot_mae_files, hotspot_comments = \
            self.write_hotspot_files()

        handle = mmproj.mmproj_project_new(pt_name)
        pt = project.Project(project_handle=handle)
        prj_row = pt.importStructure(self.ref_ct)
        prj_row.includeOnly()

        # One isosurface per probe occupancy map, all on the reference entry.
        for idx, cns_file in enumerate(probe_cns_files):
            _, probe = os.path.split(cns_file)
            probe = probe.replace('.cns', '')
            name = f'{probe.upper()}_NORM_ISOSURFACE'
            comment = f"{probe.upper()} Occupancy Map"
            row_index = prj_row.index
            color = _get_color_for_probe(probe, idx)
            set_isosurface(pt, cns_file, row_index, self.sigma, color, name,
                           comment)

        # One isosurface per hotspot, also on the reference entry.
        for ih, (cns_file, comment) in enumerate(
                zip(hotspot_cns_files, hotspot_comments)):
            name = f'Hotspot_{ih:02}'
            row_index = prj_row.index
            color = HOTSPOT_SURF_COLOR
            set_isosurface(pt, cns_file, row_index, 1, color, name, comment)

        # Probe molecules of each hotspot become separate entries.
        for ih, h in enumerate(self.hotspots):
            ct = h.get_probes_structure(ih)
            if ct is not None:
                pt.importStructure(ct)
        pt.update()
        pt.close()

        pt_zip = project.zip_project(os.path.join(os.getcwd(), pt_name),
                                     os.getcwd())
        if pt_zip:
            shutil.rmtree(pt_name)
            logger.info(f'Maestro Project {pt_name}zip is written.')
        else:
            raise ValueError(f'Error writing {pt_name}zip.')

    def prepare_ct(self, ct: "structure.Structure", probe: str = ''):
        """
        Change structure title and remove trajectory and hierarchy info.

        :param ct: Structure to modify.
        :param probe: If specified, the name of the probe. Otherwise use the
            jobname. This is the default.
        """
        ct.title = f'receptor_{probe}_{len(self.ct_data[probe])}' \
            if probe else self.jobname + '_receptor'
        props_to_delete = [
            's_m_original_cms_file', 's_chorus_trajectory_file',
            's_m_subgroup_title', 's_m_subgroupid', 'b_m_subgroup_collapsed',
            constants.CT_TYPE
        ] + list(constants.SIM_BOX)
        struc.delete_structure_properties(ct, props_to_delete)
def parse_cmd(cmdline: List[str]) -> argparse.Namespace:
    """
    Parse the command line of the cleanup script.

    :param cmdline: Argument strings to parse.
    :return: Namespace with `sigma`, `cluster_cutoff`, `checkpoint` and
        `ligand` attributes.
    """
    # (names, keyword arguments) for every argument, in registration order.
    argument_table = (
        (('-h', '-help'),
         dict(action='help',
              default=argparse.SUPPRESS,
              help="Show this help message and exit.")),
        (('-sigma',),
         dict(default=20.0,
              type=float,
              help="Isovalue to set in maestro project table.")),
        (('-cluster-cutoff',),
         dict(metavar='CUTOFF',
              default=3.0,
              type=float,
              help='Threshold used for clustering probe spots.')),
        (('checkpoint',), dict(help="Provide the main checkpoint file.")),
        # Ligand maestro file to check if any of the hotspots overlap with it.
        # Supports multiple ct inputs. Log file will report which hotspot(s)
        # overlaps with the ligand coordinates.
        (('-ligand',), dict(default=None, help=argparse.SUPPRESS)),
    )
    cli = argparse.ArgumentParser(
        description=__doc__,
        add_help=False,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    for names, settings in argument_table:
        cli.add_argument(*names, **settings)
    return cli.parse_args(cmdline)
def main(cmdline=None):
    """
    Command-line entry point: parse the arguments, optionally load the ligand
    coordinates and run the cleanup.

    :param cmdline: Argument strings handed to `parse_cmd`; the default None
        is passed through unchanged.
    """
    args = parse_cmd(cmdline)
    ligand_xyz = None
    if args.ligand:
        ligands = list(structure.StructureReader(args.ligand))
        # Coordinates of all ligand structures stacked into one array.
        ligand_xyz = np.concatenate([lig.getXYZ() for lig in ligands])
        # Report through the module logger, like the rest of this file,
        # rather than calling info() on the `log` module itself.
        logger.info(f'Reading {len(ligands)} ligands, {ligand_xyz.shape[0]} '
                    'atoms total.')
    if args.checkpoint:
        cleaner = CleanUp(args.checkpoint,
                          sigma=args.sigma,
                          cluster_cutoff=args.cluster_cutoff,
                          ligand_xyz=ligand_xyz)
        cleaner.run()
# Run the cleanup only when the module is executed as a script.
if __name__ == '__main__': main()