Source code for schrodinger.application.desmond.packages.restraint.utils

"""
Module for utilities for restraint generation.

Copyright Schrodinger, LLC. All rights reserved.
"""

import base64
import dataclasses
import json
from collections import Counter
from collections import defaultdict
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple

import numpy as np

from schrodinger import structure
from schrodinger.application.desmond import cms
from schrodinger.application.desmond.packages import analysis
from schrodinger.application.desmond.packages import msys
from schrodinger.application.desmond.packages import traj
from schrodinger.structutils import analyze

# Keys of the decoded (JSON) restraint dictionary: the current restraints and
# the "persistent" ones that `clear_encoded_restraints` keeps by default.
RESTRAINT_KEY = 'restraint'
PERSISTENT_RESTRAINT_KEY = 'persistent'

# Name of the property on the last component CT that stores the encoded
# restraints string.
_ENCODED_RESTRAINT_PROP = "s_desmond_restraint"

# Names exported by a star-import of this module.
__all__ = [
    'CrossLinkGenerationError',
    'clear_encoded_restraints',
    'get_encoded_restraints',
    'set_encoded_restraints',
    'b64_encode',
    'b64_decode',
]


class CrossLinkGenerationError(Exception):
    """
    Raised when a suitable crosslink restraint could not be generated.
    """
def clear_encoded_restraints(cms_sys: cms.Cms, keep_persistent: bool = True):
    """
    Remove encoded restraints from the `cms_sys`.

    :param cms_sys: "cms" to be processed.
    :type cms_sys: `cms.Cms`

    :param keep_persistent: "persistent" restraints disposition. When True,
        the decoded restraint entry is replaced by the "persistent" entry
        (an empty dict if there is none) and the result is re-encoded and
        stored back; nothing is done if no encoded restraints are present.
        When False, the encoded restraint property is dropped altogether.
    :type keep_persistent: bool
    """
    if not keep_persistent:
        # Discard everything, persistent restraints included.
        cms_sys.comp_ct[-1].property.pop(_ENCODED_RESTRAINT_PROP, None)
        return

    encoded = get_encoded_restraints(cms_sys)
    if not encoded:
        return

    payload = json.loads(b64_decode(encoded))
    persistent = payload.get(PERSISTENT_RESTRAINT_KEY, {})
    payload[RESTRAINT_KEY] = persistent
    set_encoded_restraints(cms_sys, b64_encode(json.dumps(payload)))
def get_encoded_restraints(cms_sys):
    """
    Look up the encoded restraints stored on the last component CT of
    `cms_sys`.

    :type cms_sys: `cms.Cms`
    :rtype: `str`
    :return: The value held under the encoded-restraint property, as
        returned by the property container's `get`.
    """
    last_ct = cms_sys.comp_ct[-1]
    return last_ct.property.get(_ENCODED_RESTRAINT_PROP)
def set_encoded_restraints(cms_sys, restr):
    """
    Save the encoded restraints string on the last component CT of
    `cms_sys`.

    :type cms_sys: `cms.Cms`
    :param restr: Encoded restraints to store.
    :type restr: `str`
    """
    last_ct = cms_sys.comp_ct[-1]
    last_ct.property[_ENCODED_RESTRAINT_PROP] = restr
def b64_encode(input_string: str) -> str:
    """
    Base64-encode a string.

    The result is a `str` rather than `bytes` so that it can be set as a
    string property of a CT.

    :param input_string: string to be encoded
    :return: base64 encoded input
    """
    raw = input_string.encode()
    encoded = base64.b64encode(raw)
    return encoded.decode()
def b64_decode(input_string: str) -> str:
    """
    Decode a base64 string produced by `b64_encode`.

    The decoded payload is converted back to a `str` so that the return
    value matches the annotation and mirrors `b64_encode` (which encodes
    the text with the default codec before base64-encoding it).

    :param input_string: base64 encoded string
    :return: the decoded text
    :raise UnicodeDecodeError: if the decoded bytes are not valid text in
        the default (UTF-8) codec.
    """
    # base64.b64decode returns bytes; decode to honour the `-> str` contract.
    return base64.b64decode(input_string).decode()
def _check_asls(model: cms.Cms, ligand_asl: str, receptor_asl: str):
    """
    Validate the ligand and receptor selections against `model`.

    Each ASL must select at least 3 atoms and the two selections must not
    share any atom.

    :raise RuntimeError: If either ligand_asl or receptor_asl are not valid.
    """

    def _select(asl, what_molecule):
        # Return the selected atoms as a set, insisting on at least 3.
        atoms = model.select_atom(asl)
        num_atoms = len(atoms)
        if num_atoms < 3:
            message = ("ERROR: Expected the %s molecule to have at least 3 "
                       "atoms, but found %d.")
            raise RuntimeError(message % (what_molecule, num_atoms))
        return set(atoms)

    ligand_atoms = _select(ligand_asl, "ligand")
    receptor_atoms = _select(receptor_asl, "receptor")

    shared_atoms = ligand_atoms & receptor_atoms
    if not shared_atoms:
        return

    # NOTE(review): the line layout inside this message could not be
    # recovered from the flattened source; it is reproduced as seen there.
    # Confirm against the upstream file.
    template = ("ERROR: ligand atoms and receptor atoms should NOT have "
                "overlaps. "
                "Ligand ASL expression: %s "
                "Receptor ASL expression: %s "
                "Atoms selected by both expressions: %s")
    raise RuntimeError(template % (ligand_asl, receptor_asl, shared_atoms))


def _check_restraint(restraint):
    """
    Ensure that a restraint was found.

    :raise CrossLinkGenerationError: If `restraint` is None.
    """
    if restraint is not None:
        return
    raise CrossLinkGenerationError(
        "ERROR: Unable to find a suitable crosslink restraint. "
        "Check the trajectory for an unstable ligand.")


# Centroid/Interaction utility functions


# Make it hashable
@dataclasses.dataclass(frozen=True, eq=True)
class PLInteractionAids:
    """
    Atom indices describing one protein-ligand interaction: the ligand atom
    and the N, CA and C atom indices on the receptor side.

    The dataclass is frozen with value equality, so instances are hashable
    and can be used as dictionary keys.
    """
    ligand_aid: int
    receptor_n_aid: int
    receptor_ca_aid: int
    receptor_c_aid: int

    @property
    def receptor_aids(self) -> Tuple[int, int, int]:
        """
        The receptor atom indices in (N, CA, C) order.
        """
        n_aid = self.receptor_n_aid
        ca_aid = self.receptor_ca_aid
        c_aid = self.receptor_c_aid
        return (n_aid, ca_aid, c_aid)
@dataclasses.dataclass
class CentroidData:
    """
    Centroid of a group of atoms together with the data used to compute it.
    """
    # aids used to compute the centroid
    aids: List[int]
    # delta vector between coords and the centroid
    diff: np.ndarray
    # centroid coords
    centroid: np.ndarray
def _get_centroid_data(ct: structure.Structure, asl: str) -> CentroidData:
    """
    Return the data related to the centroid of the given asl.

    :param ct: Structure the asl is evaluated against.
    :param asl: The asl selecting the atoms of interest.
    :return: The selected atom indexes, the per-atom delta vectors from the
        centroid, and the centroid coordinates.
    """
    aids = analyze.evaluate_asl(ct, asl)
    coords = ct.extract(aids).getXYZ()
    centroid = np.mean(coords, axis=0)
    return CentroidData(aids, coords - centroid, centroid)


def _find_centroid_aid(ct: structure.Structure, asl: str) -> int:
    """
    Return the atom index closest to the centroid.

    :param ct: Structure the asl is evaluated against.
    :param asl: The asl selecting the atoms of interest.
    """
    cd = _get_centroid_data(ct, asl)
    dist = np.linalg.norm(cd.diff, axis=1)
    return cd.aids[np.argmin(dist)]


def _find_aids_within_cutoff_of_centroid(ct: structure.Structure,
                                         asl: str,
                                         cutoff=2) -> List[int]:
    """
    Return the atom indexes that are within r_min + cutoff of the centroid,
    where r_min is the distance of the closest selected atom to the
    centroid.

    :param ct: Structure the asl is evaluated against.
    :param asl: The asl selecting the atoms of interest.
    :param cutoff: Atoms within this distance (on top of r_min) to the
        centroid are returned. Default is 2 Angstrom.
    """
    cd = _get_centroid_data(ct, asl)
    dist = np.linalg.norm(cd.diff, axis=1)
    r_min = dist[np.argmin(dist)]
    return [cd.aids[i] for i, d in enumerate(dist) if d <= r_min + cutoff]


def _find_max_dist_aid(ct: structure.Structure, asl: str, ref_aid: int) -> int:
    """
    Return the atom index for the atom that is farthest from the reference
    atom.

    :param ct: Structure the asl is evaluated against.
    :param asl: The asl for atoms to search.
    :param ref_aid: The atom index of the reference atom.
    """
    aids = analyze.evaluate_asl(ct, asl)
    ref_coords = ct.extract([ref_aid]).getXYZ()
    coords = ct.extract(aids).getXYZ()
    dist = np.linalg.norm(coords - ref_coords, axis=1)
    return aids[np.argmax(dist)]


def _find_axis_aids(ct: structure.Structure,
                    asl: str,
                    aid1: int,
                    aid2: int,
                    aid3: int,
                    delta1=30,
                    delta2=60) -> List[int]:
    """
    Given indexes of 3 atoms (not colinear), find index of atoms (aid4) so
    that the aid2-aid3-aid4 angle is 90+/-delta1 and the
    aid1-aid2-aid3-aid4 dihedral is 90+/-delta2.

    :param ct: Structure holding the atoms.
    :param asl: The asl for the candidate aid4 atoms.
    :param aid1: Index of the first reference atom.
    :param aid2: Index of the second reference atom.
    :param aid3: Index of the third reference atom.
    :param delta1: interval for angle, default is 30 degrees.
    :param delta2: interval for dihedral, default is 60 degrees.
    :return: Indexes of all atoms matching `asl` that satisfy both
        criteria (bounds are exclusive). An empty list is returned if the
        three reference atoms are not distinct, are (nearly) colinear, or
        no atom matches.
    """
    min_ang, max_ang = 90 - delta1, 90 + delta1
    min_dih, max_dih = 90 - delta2, 90 + delta2

    at1 = ct.atom[aid1]
    at2 = ct.atom[aid2]
    at3 = ct.atom[aid3]

    # Check for duplicate atoms. This is done before measuring so that
    # `ct.measure` is never asked for an angle between repeated atoms.
    if len({at1, at2, at3}) != 3:
        return []

    # Check for colinear atoms
    angle = ct.measure(at1, at2, at3)
    if angle < 5 or angle > 175:
        return []

    group_ids = []
    for i in analyze.evaluate_asl(ct, asl):
        at4 = ct.atom[i]
        angle = ct.measure(at2, at3, at4)
        dihed = ct.measure(at1, at2, at3, at4)
        if min_ang < angle < max_ang and min_dih < dihed < max_dih:
            group_ids.append(i)
    return group_ids


def _get_two_atoms(ct: structure.Structure, aid: int) -> List[Tuple[int, int]]:
    """
    Given a structure and an aid, return all pairs of connected heavy atoms
    that are bonded to the given atom or bonded to the adjacent atom.

    Both atoms of every returned pair have at least two heavy-atom
    neighbors.

    :return: List of (adjacent atom index, second atom index) pairs.
    """
    aid_pairs = []
    at1 = ct.atom[aid]
    for at2 in at1.bonded_atoms:
        if _bonded_heavy_atom_count(at2) < 2:
            continue
        # The second atom may hang off either the given atom or its neighbor.
        for at3 in list(at1.bonded_atoms) + list(at2.bonded_atoms):
            if (at3 != at1 and at3 != at2 and
                    _bonded_heavy_atom_count(at3) >= 2):
                aid_pairs.append((at2.index, at3.index))
    return aid_pairs


def _bonded_heavy_atom_count(at: structure._structure._StructureAtom) -> int:
    """
    Return the number of heavy atoms (atomic number > 1) attached to the
    given atom.
    """
    return sum(bonded_at.atomic_number > 1 for bonded_at in at.bonded_atoms)


def _get_heavy_asl(asl: str) -> str:
    """
    Return `asl` restricted to non-hydrogen atoms.
    """
    return f'({asl}) and (not a.e H)'


def _get_backbone_aids(
    ct: structure.Structure, receptor_aid: int
) -> Tuple[Optional[int], Optional[int], Optional[int]]:
    """
    Given a structure and a protein atom index, return the corresponding
    N, Ca, C backbone aids.

    :return: (N, CA, C) atom indexes of the residue that contains
        `receptor_aid`, or (None, None, None) if any of the three backbone
        atoms is not available.
    """
    res = ct.atom[receptor_aid].getResidue()
    n = res.getBackboneNitrogen()
    ca = res.getAlphaCarbon()
    c = res.getCarbonylCarbon()
    if not all([n, ca, c]):
        return (None, None, None)
    return (n.index, ca.index, c.index)


def _get_heavy_aid(ct: structure.Structure, aid: int) -> Optional[int]:
    """
    Return the heavy atom that is attached to the given `aid` and is not
    terminal. Return None if such an atom could not be found.

    The search starts at `aid` and, as long as the current atom has exactly
    one heavy neighbor, steps onto that neighbor. An atom with two or more
    heavy neighbors ends the search and its index is returned.
    """
    atm = ct.atom[aid]
    # Avoid hydrogen and terminal heavy atoms
    # Need to track this to prevent an infinite loop with CH3-CH3
    searched_atoms = set()
    while True:
        nbond_heavy = _bonded_heavy_atom_count(atm)
        if nbond_heavy == 0:
            return None
        elif nbond_heavy == 1:
            for aa in atm.bonded_atoms:
                if aa.atomic_number > 1:
                    if aa in searched_atoms:
                        # Walked back onto a visited atom: nothing suitable.
                        return None
                    atm = aa
                    searched_atoms.add(atm)
        else:
            break
    return atm.index


def _is_so2(ct: structure.Structure, aid: int) -> bool:
    """
    Return True if the atom is sulfur with four heavy-atom neighbors and it
    is bonded to two oxygens.
    """
    lig_at = ct.atom[aid]
    num_oxygen = 0
    if lig_at.atomic_number == 16 and _bonded_heavy_atom_count(lig_at) == 4:
        num_oxygen = sum(ati.atomic_number == 8 for ati in lig_at.bonded_atoms)
    return num_oxygen == 2


def _get_protein_ligand_interaction_freq_dict(
    msys_model: "msys.System",  # noqa: F821
    cms_model: cms.Cms,
    tr: List["traj.TrajFrame"],
    ligand_asl: str,
    receptor_asl: str,
    num_traj_segments: int = 1
) -> Optional[Dict[PLInteractionAids, List[float]]]:
    """
    Return the normalized frequency for the interactions between the
    protein and ligand. The keys are `PLInteractionAids`.

    Hydrogen bond and salt bridge interactions are both counted. Each
    interaction is keyed by the ligand heavy atom found by `_get_heavy_aid`
    and the N, CA, C backbone atoms of the interacting receptor residue.

    :param msys_model: msys model passed on to the interaction analyzers.
    :param cms_model: cms model the ASLs are evaluated against.
    :param tr: Trajectory frames to analyze.
    :param ligand_asl: ASL selecting the ligand atoms.
    :param receptor_asl: ASL selecting the receptor atoms.
    :param num_traj_segments: Number of segments to split the trajectory
        into prior to running the analysis. The frequencies are computed
        for each segment.
    :return: For every interaction, one frequency per trajectory segment
        (interaction count divided by the number of frames in the segment).
        None is returned if no interaction was found, or if a ligand atom
        has no suitable attached heavy atom.
    """
    # NOTE(review): with this segment length, frames at the end of the
    # trajectory are not part of any segment when len(tr) is not close to a
    # multiple of num_traj_segments (e.g. 10 frames in 3 segments gives
    # 3 x 3 frames). Confirm that this is intended.
    nfr = (len(tr) + 1) // num_traj_segments
    segments = [tr[i * nfr:(i + 1) * nfr] for i in range(num_traj_segments)]

    # ASLs
    receptor_aids = analyze.evaluate_asl(cms_model, receptor_asl)
    ligand_aids = analyze.evaluate_asl(cms_model, ligand_asl)
    # Membership is tested for every interaction of every frame.
    ligand_aid_set = set(ligand_aids)

    # Run protein-ligand interaction analyzers
    freqs = []
    for segment in segments:
        analyzer1 = analysis.HydrogenBondFinder(msys_model, cms_model,
                                                receptor_aids, ligand_aids)
        analyzer2 = analysis.SaltBridgeFinder(msys_model, cms_model,
                                              receptor_aids, ligand_aids)
        result1 = analysis.analyze(segment, analyzer1)
        result2 = analysis.analyze(segment, analyzer2)
        # Merge the two per-frame interaction lists.
        frame_results = [a + b for a, b in zip(result1, result2)]

        # Get the frequency for each pair of interactions
        freq = Counter()
        for frame_result in frame_results:
            for (pro_aid, lig_aid) in frame_result:
                if lig_aid not in ligand_aid_set:
                    # The pair is listed ligand-first; put the receptor
                    # atom first.
                    pro_aid, lig_aid = lig_aid, pro_aid
                pro_n_aid, pro_ca_aid, pro_c_aid = _get_backbone_aids(
                    cms_model.fsys_ct, pro_aid)
                if pro_n_aid is None:
                    # Skip interaction with terminal residue
                    continue
                lig_aid = _get_heavy_aid(cms_model.fsys_ct, lig_aid)
                if lig_aid is None:
                    print('Could not find attached heavy atom.')
                    return None
                pair = PLInteractionAids(lig_aid, pro_n_aid, pro_ca_aid,
                                         pro_c_aid)
                freq[pair] += 1
        freqs.append(freq)

    # Corner case: interaction missing from trajectory segment.
    # Make sure all segments have the same interaction keys.
    pairs = {pair for freq in freqs for pair in freq}
    for freq in freqs:
        for pair in pairs:
            if pair not in freq:
                freq[pair] = 0.0

    # No hydrogen bond or salt bridge interactions found
    if not pairs:
        return None

    # Normalize the frequency for each segment and store as a list
    result = defaultdict(list)
    for freq, segment in zip(freqs, segments):
        # Guard against an empty segment to avoid dividing by zero.
        num_frames = len(segment) or 1
        for pair, count in freq.items():
            result[pair].append(count / num_frames)
    return result