# Source code for schrodinger.application.desmond.stage.app.absolute_binding.restraint

from typing import TYPE_CHECKING
from typing import List
from typing import Optional
from typing import Tuple

import numpy as np

from schrodinger import structure
from schrodinger.adapter import evaluate_smarts
from schrodinger.application.desmond import cms
from schrodinger.application.desmond.constants import FEP_RESTRAIN
from schrodinger.application.desmond.constants import REST_REGION_RULE
from schrodinger.application.desmond.stage.prepare import forcefield
from schrodinger.structutils import analyze

if TYPE_CHECKING:
    from schrodinger.application.desmond.packages import traj
    from schrodinger.application.scisol.packages.fep import graph


def prepare_ligand_restraints(st: structure.Structure, ligand_asl: str,
                              force_constants: List[float],
                              sigmas: List[float], schedule_name: str) -> str:
    """
    Encode fixed-width dihedral restraints for the ligand's six membered
    rings as a string.

    The ligand is extracted from `st` and marked via
    `forcefield.add_restraint_atom_marker`. Every heavy atom torsion of the
    rings selected by `_extract_six_member_rings` then receives a restraint
    using the given `sigmas`.

    :param st: Simulation structure on which the restraints will be applied.
    :param ligand_asl: ASL to identify the ligand.
    :param force_constants: Force constants for the restraint at the lambda
        endpoints.
    :param sigmas: Sigma values for the restraint at the lambda endpoints.
    :param schedule_name: Name of the schedule associated with the restraint.
    :return: The restraints, newline separated and wrapped in square
        brackets, or an empty string if no suitable ring was found.
    """
    ligand_st = st.extract(analyze.evaluate_asl(st, ligand_asl),
                           copy_props=True)
    forcefield.add_restraint_atom_marker(ligand_st, ligand_asl)

    ring_st = _extract_six_member_rings(ligand_st)
    if ring_st is None:
        print("Could not find six membered rings to restrain.")
        return ''

    entries = []
    for torsion in _get_heavy_atom_torsions(ring_st):
        entries.append(
            _prepare_restraint(torsion, ligand_asl, force_constants, sigmas,
                               schedule_name))
    body = '\n'.join(entries)
    return f'[{body}]'
def prepare_adaptive_ligand_restraints(
        msys_model: "msys.System",  # noqa: F821
        cms_model: cms.Cms,
        trj: List["traj.TrajFrame"],
        ligand_asl: str,
        force_constants: List[float],
        schedule_name: str) -> str:
    """
    Given a cms model and corresponding trajectory, return the dihedral
    restraint parameters encoded as a string.

    This adds restraints with adaptive width to the unique rotatable heavy
    atom torsions of the ligand. The width of each restraint is derived from
    the spread of that torsion's dihedral values over the trajectory.

    :param msys_model: Msys model used for trajectory analysis.
    :param cms_model: Cms model read in with `traj_util.read_cms_and_traj`.
    :param trj: Trajectory read in with `traj_util.read_cms_and_traj`.
    :param ligand_asl: ASL to identify the ligand.
    :param force_constants: Force constants for the restraint at the lambda
        endpoints.
    :param schedule_name: Name of the schedule associated with the restraint.
    :return: The restraints, newline separated and wrapped in square
        brackets, or an empty string if no torsions were found.
    """
    torsions, torsion_trj = _analyze_torsions(msys_model, cms_model, trj,
                                              ligand_asl)
    if torsions is None:
        print(
            "WARNING: Could not find torsions to determine adaptive ligand restraints."
        )
        return ''

    # Only the widths are encoded into the restraint string.
    # NOTE(review): the reference (center) values are not passed on;
    # presumably the restraint is centered on the structure's current
    # dihedrals (cf. `use_representative_frame`) -- confirm.
    _, torsion_sigmas = _get_adaptive_torsion_parameters(torsion_trj)
    restraints = [
        # The same adaptive width is used at both lambda endpoints.
        _prepare_restraint(torsion, ligand_asl, force_constants,
                           [sigma, sigma], schedule_name)
        for torsion, sigma in zip(torsions, torsion_sigmas)
    ]
    return '[' + '\n'.join(restraints) + ']'
def use_representative_frame(
        msys_model: "msys.System",  # noqa: F821
        cms_model: cms.Cms,
        trj: List["traj.TrajFrame"],
        ligand_asl: str):
    """
    Replace the coordinates of `cms_model` with those of the trajectory frame
    whose ligand torsions lie closest to their adaptive reference values.

    If no ligand torsions are found, a warning is printed and `cms_model` is
    left untouched.

    :param msys_model: Msys model used for trajectory analysis.
    :param cms_model: Cms model read in with `traj_util.read_cms_and_traj`.
        Its coordinates are updated in place from the representative frame.
    :param trj: Trajectory read in with `traj_util.read_cms_and_traj`.
    :param ligand_asl: ASL to identify the ligand.
    """
    from schrodinger.application.desmond.packages import topo
    torsions, torsion_trj = _analyze_torsions(msys_model, cms_model, trj,
                                              ligand_asl)
    if torsions is None:
        print(
            "WARNING: Could not find torsions to determine representative frame."
        )
        return

    centers = _get_adaptive_torsion_parameters(torsion_trj)[0]
    frame = trj[_get_representative_frame(torsion_trj, centers)]

    # Update the cms_model in place with the representative frame.
    if not cms_model.need_msys:
        cms_model.update_with_frame(frame)
    else:
        topo.update_cms(cms_model, frame)
def _analyze_torsions(
        msys_model: "msys.System",  # noqa: F821
        cms_model: cms.Cms,
        trj: List["traj.TrajFrame"],
        ligand_asl: str
) -> Tuple[Optional[List[List[int]]], Optional[np.ndarray]]:
    """
    Find the unique rotatable heavy atom torsions of the ligand and measure
    their dihedral angles in every frame of the trajectory.

    :param msys_model: Msys model used for trajectory analysis.
    :param cms_model: Cms model corresponding to `trj`.
    :param trj: Trajectory frames to analyze.
    :param ligand_asl: ASL to identify the ligand.
    :return: The torsions (each a list of four `FEP_RESTRAIN` property values)
        and an N_torsions x N_frames array holding the dihedral value of each
        torsion in each frame, or `(None, None)` if no torsion was found.
    """
    from schrodinger.application.desmond.packages import analysis
    st = cms_model.fsys_ct.copy()
    restraint_to_atom_idx = forcefield.add_restraint_atom_marker(st, ligand_asl)

    # Get unique rotatable torsions for the ligand
    ligand_st = st.extract(list(restraint_to_atom_idx.values()),
                           copy_props=True)
    torsions = _get_unique_rotatable_heavy_atom_torsions(ligand_st)
    if not torsions:
        return None, None

    # Translate the FEP_RESTRAIN values of each torsion back to atom indices
    # in `st` and measure the dihedral over the trajectory.
    torsion_aids = [
        [restraint_to_atom_idx[t] for t in torsion] for torsion in torsions
    ]
    torsion_analyzers = [
        analysis.Torsion(msys_model, cms_model, *torsion_aid)
        for torsion_aid in torsion_aids
    ]
    torsion_trj = np.array(analysis.analyze(trj, *torsion_analyzers))
    # `analyze` squeezes its result (e.g. a single torsion yields a flat
    # [N_frames] array). Reshaping against the known torsion count restores
    # the [N_torsions, N_frames] layout for any squeezed shape, rather than
    # only the 1-D case.
    torsion_trj = np.reshape(torsion_trj, (len(torsions), -1))
    return torsions, torsion_trj


def _extract_six_member_rings(
        ligand_st: structure.Structure) -> Optional[structure.Structure]:
    """
    Return a structure containing just the non-aromatic, non-fused six member
    rings of the ligand, or `None` if there are none.

    :param ligand_st: Ligand structure to analyze.
    """
    # SMARTS: smallest ring size 6, not aromatic, member of exactly one ring,
    # and not an aliphatic atom double bonded to another aliphatic atom.
    six_nonaromatic_aids = [
        a[0] for a in evaluate_smarts(ligand_st, "[r6;!a;R1;!$(A=A)]")
    ]
    if not six_nonaromatic_aids:
        return None
    six_nonaromatic_st = ligand_st.extract(six_nonaromatic_aids,
                                           copy_props=True)
    # If two cyclohexyl groups are linked together, they can have a rotatable
    # bond connecting them. In this case, separate the rings by deleting the
    # rotatable bonds here. The list() is required because bonds are deleted
    # while iterating.
    for b in list(analyze.rotatable_bonds_iterator(six_nonaromatic_st)):
        six_nonaromatic_st.deleteBond(*b)
    # If the six member ring has some double bonds or is fused,
    # it can pass the above check but will not pass here.
    six_ring_aids = [
        a[0] for a in evaluate_smarts(six_nonaromatic_st, "[r6]")
    ]
    if not six_ring_aids:
        return None
    return six_nonaromatic_st.extract(six_ring_aids, copy_props=True)


def _get_heavy_atom_torsions(st: structure.Structure) -> List[List[int]]:
    """
    Return a list of heavy atom torsions for the input structure.

    Each torsion contains the corresponding `FEP_RESTRAIN` property for the
    four atoms that make up the dihedral.

    :param st: Ligand structure to analyze.
    """
    heavy_atoms = [a for a in st.atom if a.atomic_number > 1]
    result = []
    for t in analyze.torsion_iterator(st, atoms=heavy_atoms):
        result.append(
            [st.atom[t[k]].property[FEP_RESTRAIN] for k in range(4)])
    return result


def _prepare_restraint(torsion: List[int], ligand_asl: str,
                       force_constants: List[float], sigmas: List[float],
                       schedule_name: str) -> str:
    """
    Given a torsion, output a string which encodes the restraint parameters.

    Each of the four atoms is selected by an ASL combining `ligand_asl` with
    the atom's `FEP_RESTRAIN` value.

    :param torsion: Torsion specified as a list of integers matching the
        `FEP_RESTRAIN` property.
    :param ligand_asl: ASL to identify the ligand.
    :param force_constants: Force constants for the restraint at the lambda
        endpoints.
    :param sigmas: Sigma values for the restraint at the lambda endpoints.
    :param schedule_name: Name of the schedule associated with the restraint.
    :return: The restraint encoded as a brace-delimited string.
    """
    atoms = [
        f'"({ligand_asl}) and a.{FEP_RESTRAIN} {torsion[k]}"' for k in range(4)
    ]
    return f"""{{name = alchemical_improper_fbhw atoms = [{' '.join(atoms)}] force_constants = [{force_constants[0]} {force_constants[1]}] sigmaA = {sigmas[0]} sigmaB = {sigmas[1]} schedule = {schedule_name} }}"""


def _get_heaviest_connected_atom(st: structure.Structure, input_atom: int,
                                 excluded_atom_idx: int) -> Optional[int]:
    """
    Return the index of the heaviest atom bonded to `input_atom`.

    Terminal atoms (bond total of 1) and `excluded_atom_idx` are skipped.
    Weight is compared by atomic number; on ties the first atom found wins.

    :return: The atom index, or `None` if no eligible atom exists.
    """
    connected_atom_idx = None
    for at in st.atom[input_atom].bonded_atoms:
        if at.bond_total <= 1 or at.index == excluded_atom_idx:
            continue
        if (connected_atom_idx is None or at.atomic_number >
                st.atom[connected_atom_idx].atomic_number):
            connected_atom_idx = at.index
    return connected_atom_idx


def _get_unique_rotatable_heavy_atom_torsions(
        st: structure.Structure) -> List[List[int]]:
    """
    Return a list of heavy atom rotatable torsions for the input structure.

    There is one torsion per rotatable bond. Its outer atoms are the
    heaviest non-terminal neighbors on either side of the bond. Bonds
    without such a neighbor on both sides are skipped.

    Each torsion contains the corresponding `FEP_RESTRAIN` property for the
    four atoms that make up the dihedral.

    :param st: Ligand structure to analyze.
    """
    result = []
    for i1, i2 in analyze.rotatable_bonds_iterator(st):
        # Heaviest non-terminal neighbor on each side of the rotatable bond
        i0 = _get_heaviest_connected_atom(st, i1, i2)
        i3 = _get_heaviest_connected_atom(st, i2, i1)
        if i0 is None or i3 is None:
            continue
        result.append(
            [st.atom[i].property[FEP_RESTRAIN] for i in (i0, i1, i2, i3)])
    return result


def _get_adaptive_torsion_parameters(
        torsion_trj: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """
    Given the N_torsions x N_frames matrix `torsion_trj`, determine the center
    and sigmas for the torsion restraints.

    For each torsion, the smallest periodic (360 degree) interval covering
    all sampled values is found. The center is the midpoint of that interval,
    wrapped into (-180, 180]. The sigma is half of the interval width plus a
    5 degree margin, clipped to [0, 180].

    :param torsion_trj: The N_torsions x N_frames matrix containing the
        dihedral angle values for each torsion over time.
    :return: Arrays of reference values and sigmas, one entry per torsion.
    """
    sorted_torsion_trj = np.sort(torsion_trj, axis=-1)

    # Candidate covering intervals per torsion:
    # - column 0 spans the smallest to the largest value (no wrap-around);
    # - column k >= 1 starts at sorted value k and wraps around to sorted
    #   value k-1, i.e. 360 minus the gap between those neighbors.
    wrapped_intervals = 360 - (sorted_torsion_trj[:, 1:] -
                               sorted_torsion_trj[:, :-1])
    direct_interval = np.expand_dims(
        sorted_torsion_trj[:, -1] - sorted_torsion_trj[:, 0], -1)
    torsion_interval = np.concatenate([direct_interval, wrapped_intervals],
                                      axis=1)

    # Pick the smallest interval; its start is the sorted value at the same
    # column index (see layout above).
    min_interval = np.amin(torsion_interval, -1)
    min_idxs = np.expand_dims(np.argmin(torsion_interval, -1), -1)
    torsion_start = np.take_along_axis(sorted_torsion_trj, min_idxs, axis=-1)
    torsion_start = np.reshape(torsion_start, (-1,))

    # Reference value is the interval midpoint, wrapped into (-180, 180]
    ref_values = np.fmod(torsion_start + min_interval / 2.0, 360.0)
    ref_values[ref_values > 180] -= 360.0
    ref_values[ref_values <= -180] += 360.0

    # The width is half of the smallest covering interval plus a margin
    epsilon = 5.0
    sigmas = np.clip(min_interval / 2.0 + epsilon, 0, 180.0)
    return ref_values, sigmas


def _get_representative_frame(torsion_trj: np.ndarray,
                              ref_values: np.ndarray) -> int:
    """
    Given the trajectory of torsions and corresponding ref_values, return the
    frame index which is closest to the ref_values.

    Closeness is the sum over torsions of the periodic absolute deviation
    from the reference value.

    :param torsion_trj: The N_torsions x N_frames matrix containing the
        dihedral angle values for each torsion over time.
    :param ref_values: Array containing the reference dihedral angle value
        for each torsion.
    :return: Index of the representative frame as a plain Python int.
    """
    torsion_diff_abs = np.abs(torsion_trj - np.expand_dims(ref_values,
                                                           axis=-1))
    # Periodic distance: the shorter way around the 360 degree circle
    torsion_diff = np.minimum(torsion_diff_abs, 360 - torsion_diff_abs)
    # Sum across torsions -> [N_frames]
    sum_diff = np.sum(torsion_diff, axis=0)
    # Frame with the smallest difference is representative
    return int(np.argmin(sum_diff))
def overwrite_hotatoms(graph: 'graph.Graph', has_ligand_restraint: bool):
    """
    Overwrite the hot atom settings of every edge in `graph` based on whether
    ligand restraints are used.

    The complex leg always uses `REST_REGION_RULE.DEFAULT`. The solvent leg
    uses `REST_REGION_RULE.DEFAULT` when `has_ligand_restraint` is set and
    `REST_REGION_RULE.ALL` otherwise.

    :param graph: FEP graph whose edges are updated.
    :param has_ligand_restraint: Whether ligand restraints are applied.
    """
    from schrodinger.application.scisol.packages.fep import hot_atom
    if has_ligand_restraint:
        solvent_rule = REST_REGION_RULE.DEFAULT
    else:
        solvent_rule = REST_REGION_RULE.ALL
    hot_atom.overwrite_hotatoms(graph.edges_iter(),
                                leg_solvent=solvent_rule,
                                leg_complex=REST_REGION_RULE.DEFAULT)