# Source code for schrodinger.application.desmond.packages.restraint.builder

"""
Restraint generation for cross-CT terms and all terms supported
by the desmond backend, including alchemical terms.

Either a single term or a generator can be used.
For a single term, each selection corresponds to a single atom.

Two kinds of generators are implemented now:

product: The product of all selections is used to generate all
         the terms.  Use case is to keep alchemical ions away from
         places they may get stuck.

connected: One selection is evaluated to generate terms for bond, angle
           and torsion.  Use case is the alchemical restraints
           on protein conformations.

Reference distance, angle and torsion values are computed for generated
terms.  For alchemical terms, reference coordinates saved previously will
be used for these calculations if available.

Copyright Schrodinger, LLC. All rights reserved.
"""

import copy
import dataclasses
import enum
import functools
import itertools
import json
import pprint
from typing import Iterable
from typing import List
from typing import Tuple

import numpy as np

from schrodinger.application.desmond import cms
from schrodinger.application.desmond import constants
from schrodinger.application.desmond.packages import msys
from schrodinger.structutils import analyze
from schrodinger.structutils import measure
from schrodinger.utils import log
from schrodinger.utils import sea

from .utils import PERSISTENT_RESTRAINT_KEY
from .utils import RESTRAINT_KEY
from .utils import b64_decode
from .utils import b64_encode
from .utils import get_encoded_restraints
from .utils import set_encoded_restraints

logger = log.get_logger(name="restraint_builder")

# Key under which every serialized table stores the atom ids of its terms.
_ATOMS_KEY = "atoms"

# Term arity (number of atoms per term) -> iterator enumerating the connected
# terms of that arity in a structure: bonds (2), angles (3), torsions (4).
_TERM_ITERATORS = {
    2: analyze.bond_iterator,
    3: analyze.angle_iterator,
    4: analyze.torsion_iterator,
}

# Term arity -> function measuring the reference geometry of such a term:
# distance (2), bond angle (3), dihedral angle (4).
_MEASURE_FUNCTIONS = {
    2: measure.measure_distance,
    3: measure.measure_bond_angle,
    4: measure.measure_dihedral_angle,
}

# Name fragments of tables / parameters used throughout this module.
_FBHW, _FC, _SIGMA = 'fbhw', 'fc', 'sigma'

# Names of the positional restraint tables, individually and as a set.
_POSRE_HARM, _POSRE_FBHW = 'posre_harm', 'posre_fbhw'
_POSRE_HARM_OR_FBHW = frozenset((_POSRE_HARM, _POSRE_FBHW))


class GeneratorType(enum.Enum):
    """
    Supported values of the ``generator`` key in a restraint specification.

    ``product`` builds terms from the cartesian product of all atom
    selections; ``connected`` builds bond/angle/torsion terms from the
    connectivity within a single atom selection.
    """
    PRODUCT = 'product'
    CONNECTED = 'connected'
_dummy_msys = msys.CreateSystem() # empty msys to get table schemas __all__ = [ 'AtomID', 'Restraints', 'RestraintBuilder', 'generate_conf_pose_restraints', 'has_positional_restraints', ]
def decode_atom_ids(encoded_ids):
    """
    Lazily decode serialized atom ids.

    :param encoded_ids: atom ids, each encoded either as a ``(ct, atom)``
        pair or as a ``{'ct': ct, 'atom': atom}`` dict
    :return: iterator yielding one `AtomID` per encoded id
    """
    return (AtomID.make(encoded) for encoded in encoded_ids)
def get_natoms_in_term(table_name: str) -> int:
    """
    Return the arity (number of atoms in each term) of the desmond term
    table called `table_name`.

    :param table_name: name of the desmond term table
    :return: number of atoms for each term
    """
    # addTableFromSchema only creates the named table once; for an already
    # existing table it returns that table instead.
    # c.f. http://opengrok/xref/desmond-gpu-src/other/msys/python/__init__.py#1277
    return _dummy_msys.addTableFromSchema(table_name).natoms
def get_table_schema(table_name: str):
    """
    Return the property names of the desmond term table `table_name`.

    :param table_name: name of the desmond term table
    :return: (param_props, table_props) where both param_props and
        table_props are frozensets of property name strings; the
        'constrained' term property is excluded from table_props
    :rtype: `tuple(frozenset(str), frozenset(str))`
    """
    # Creates the table only if it does not exist yet (see
    # get_natoms_in_term).
    schema_table = _dummy_msys.addTableFromSchema(table_name)
    param_props = frozenset(schema_table.params.props)
    term_props = frozenset(
        prop for prop in schema_table.term_props if prop != 'constrained')
    return (param_props, term_props)
@dataclasses.dataclass
class AtomID:
    """
    Atom is specified by two numbers: ct index and atom index within this ct;
    ct indices starts from 0;
    ct == 0 indicates that the atom number is a gid used by desmond backend;
    ct == 1 is the "full system";
    ct >= 2 correspond to the component cts
    """
    ct: int
    atom: int

    @staticmethod
    def make(v):
        """
        Build an `AtomID` from its serialized form.

        :param v: either a dict with 'ct' and 'atom' keys, or an indexable
            whose first two items are the ct and the atom index
        :rtype: `AtomID`
        """
        if not isinstance(v, dict):
            return AtomID(ct=v[0], atom=v[1])
        return AtomID(**v)
class _GenericParams(dict):
    """
    Parameters of a generic (non-positional) desmond term table.

    Stored as a dict mapping every table property name, plus `_ATOMS_KEY`,
    to a list holding one entry per term; all lists are kept parallel.
    """

    def __init__(self, *, table_name: str, dct=None):
        """
        :param table_name: desmond term table name
        :type table_name: str
        :param dct: dictionary deserialized from JSON
        :type dct: dict
        """
        self._is_alchemical = table_name.startswith('alchemical_')
        param_props, term_props = get_table_schema(table_name)
        for p in sorted({_ATOMS_KEY} | param_props | term_props):
            self[p] = []
        if dct:
            for k, v in dct.items():
                if k == _ATOMS_KEY:
                    # each term's atoms are serialized as encoded atom ids
                    self[k] += [list(decode_atom_ids(atoms)) for atoms in v]
                else:
                    self[k] += v

    def addTerm(self, atoms: List[AtomID], props: dict) -> int:
        """
        props contain the actual force-field parameters, force constants,
        equilibrium angles and other parameters that specific to the term,
        e.g. schedule for alchemical terms

        :param atoms: list of atom ids
        :param props: dictionary of parameters keyed by name of the parameter
            (`str`), including both table properties (e.g. schedule) and force
            field parameter properties
        :return: index of the term added
        """
        atoms_list = self[_ATOMS_KEY]
        atoms_list.append(atoms)
        for k, v in props.items():
            self[k].append(v)
        return len(atoms_list) - 1

    @property
    def is_positional(self):
        return False

    @property
    def is_alchemical(self):
        return self._is_alchemical

    def incorporate(self, other: '_GenericParams'):
        """
        merges `other` into `self`

        :raise NotImplementedError: if both tables already contain terms
        """
        assert self.keys() == other.keys()
        if self[_ATOMS_KEY] and other[_ATOMS_KEY]:
            raise NotImplementedError(
                "merging is supported for positional restraints only")
        elif other[_ATOMS_KEY]:
            for k, v in self.items():
                v += copy.deepcopy(other[k])


class _PosreHarmParams:
    """
    Stores group of position restraints as a numpy structured array for
    faster creation (if from existing arrays) and merging.
    """
    AID_TYPE = [('ct', 'i8'), ('atom', 'i8')]
    DTYPE = [("atom_id", AID_TYPE), ("k", "f8", (3,)), ("ref", "f8", (3,))]
    # field(s) identifying "the same" restraint when merging two tables
    INTERSECT_KEYS = "atom_id"

    def __init__(self,
                 *,
                 atom_ids=None,
                 k=None,
                 ref=None,
                 dct=None,
                 **kwargs):
        """
        :param atom_ids: list of lists of AtomIDs
        :param k: force constants
        :param ref: reference positions
        :param dct: dictionary de-serialized from json; if truthy, overrides
            other arguments
        :param kwargs: ignored; lets every table class be constructed the
            same way (e.g. with `table_name=...`)
        """
        if dct:
            atom_ids = self._transcode_atom_ids(dct[_ATOMS_KEY])
            ref = list(zip(dct['x0'], dct['y0'], dct['z0']))
            k = list(zip(dct['fcx'], dct['fcy'], dct['fcz']))
        num_entries = len(atom_ids or [])
        # Zero-initialized so that fields not assigned below (e.g. sigma in
        # the FBHW subclass) never contain uninitialized memory.
        self.arr = np.zeros(shape=num_entries, dtype=self.DTYPE)
        if num_entries:
            self.arr['atom_id'] = atom_ids
            self.arr['ref'] = ref
            self.arr['k'] = k

    @staticmethod
    def _transcode_atom_ids(list_of_encoded_aid_lists):
        """
        Helper function that takes list-of-lists of atom IDs encoded
        either as (ct, atom) pairs or {'ct': ct, 'atom': atom} dicts
        and returns list of (ct, atom) tuples (see c-tor). Only the first
        atom id of each inner list is used.
        """
        list_of_tuples = []
        for encoded_aids in list_of_encoded_aid_lists:
            aid = next(decode_atom_ids(encoded_aids))
            list_of_tuples.append(dataclasses.astuple(aid))
        return list_of_tuples

    def _intersect_with(self, other) -> Tuple[np.ndarray, np.ndarray]:
        """
        Find the duplicate entries in self and other and return their indices,
        respectively
        """
        _, self_idx, other_idx = np.intersect1d(self.arr[self.INTERSECT_KEYS],
                                                other.arr[self.INTERSECT_KEYS],
                                                assume_unique=True,
                                                return_indices=True)
        return self_idx, other_idx

    def incorporate(self, other: '_PosreHarmParams'):
        """
        merges `other` into `self`: for duplicate entries the larger force
        constant and the reference position from `other` are kept; entries
        only present in `other` are appended
        """
        assert isinstance(other, _PosreHarmParams)
        self_idx, other_idx = self._intersect_with(other)
        other_k = other.arr['k'][other_idx]
        self.arr['k'][self_idx] = np.maximum(self.arr['k'][self_idx], other_k)
        self.arr['ref'][self_idx] = other.arr['ref'][other_idx]
        mask = np.ones(other.arr.size, dtype=bool)
        mask[other_idx] = False
        self.arr = np.concatenate([self.arr, other.arr[mask]])

    def copyRefs(self, other: '_PosreHarmParams'):
        """
        copies reference positions from the `other`
        """
        assert isinstance(other, _PosreHarmParams)
        self_idx, other_idx = self._intersect_with(other)
        self.arr['ref'][self_idx] = other.arr['ref'][other_idx]

    @property
    def atom_ids_as_tuples(self):
        # one single-atom list per entry, matching the generic table format
        return [
            [(int(aid['ct']), int(aid['atom']))] for aid in self.arr['atom_id']
        ]

    @property
    def refs(self):
        for a, v in zip('xyz', self.arr['ref'].transpose().tolist()):
            yield f'{a}0', v

    def asdict(self):
        outcome = {_ATOMS_KEY: self.atom_ids_as_tuples}
        outcome.update(self.refs)
        for a, v in zip('xyz', self.arr['k'].transpose().tolist()):
            outcome[f'{_FC}{a}'] = v
        return outcome

    def __len__(self):
        return len(self.arr)

    @property
    def is_positional(self):
        return True

    @property
    def is_alchemical(self):
        return False


class _PosreFBHWParams(_PosreHarmParams):
    """
    Positional restraints with a scalar force constant and a per-entry
    sigma; entries are identified by (atom id, sigma) when merging.
    """
    DTYPE = [("atom_id", _PosreHarmParams.AID_TYPE), ("k", "f8"),
             ("ref", "f8", (3,)), (_SIGMA, "f8")]
    INTERSECT_KEYS = ["atom_id", _SIGMA]

    def __init__(self,
                 *,
                 atom_ids=None,
                 k=None,
                 ref=None,
                 sigma=None,
                 dct=None,
                 **kwargs):
        if dct:
            atom_ids = self._transcode_atom_ids(dct[_ATOMS_KEY])
            ref = list(zip(dct['x0'], dct['y0'], dct['z0']))
            k = dct[_FC]
            sigma = dct[_SIGMA]
        super().__init__(atom_ids=atom_ids, k=k, ref=ref)
        # `is not None` (rather than truthiness) so that sigma == 0.0 is
        # stored instead of being silently skipped.
        if sigma is not None:
            # will broadcast as needed
            self.arr[_SIGMA] = sigma

    def asdict(self):
        outcome = {
            _ATOMS_KEY: self.atom_ids_as_tuples,
            _FC: self.arr['k'].tolist(),
            _SIGMA: self.arr[_SIGMA].tolist()
        }
        outcome.update(self.refs)
        return outcome


class _JsonEncoder(json.JSONEncoder):
    """
    JSON encoder that serializes `AtomID` as a (ct, atom) pair and positional
    restraint tables via their `asdict()`.
    """

    @functools.singledispatchmethod
    def default(self, obj):
        return super().default(obj)

    @default.register
    def _(self, obj: AtomID):
        return dataclasses.astuple(obj)

    @default.register
    def _(self, obj: _PosreHarmParams):
        return obj.asdict()
class Restraints:
    """
    Container of restraint term tables, kept in two groups: regular tables
    and "persistent" (aka "permanent") tables.  Persistent tables are merged
    into the regular ones on serialization, which assumes they support
    merging (only "posre_*" at the moment).
    """

    def __init__(self, *, text=None):
        """
        :param text: pre-existing serialized restraints (as json) to build upon
        :type text: str or NoneType
        """
        if text is None:
            data = {RESTRAINT_KEY: {}, PERSISTENT_RESTRAINT_KEY: {}}
        else:
            data = json.loads(text)

        def build(serialized_tables):
            built = {}
            for name, table_dct in serialized_tables.items():
                table_class = self._getTableClass(name)
                built[name] = table_class(table_name=name, dct=table_dct)
            return built

        self._restraints = build(data.get(RESTRAINT_KEY, {}))
        self._persistent = build(data.get(PERSISTENT_RESTRAINT_KEY, {}))

    @staticmethod
    def _getTableClass(table_name):
        # positional tables get dedicated numpy-backed classes
        if table_name == _POSRE_HARM:
            return _PosreHarmParams
        if table_name == _POSRE_FBHW:
            return _PosreFBHWParams
        return _GenericParams

    def getTable(self, table_name: str, persistent: bool = False) -> object:
        """
        Gets parameters table by name, creating an empty one on first access.

        :param table_name: name of the desmond term table
        :param persistent: persistent vs regular table
        :return: requested table
        :rtype: `_GenericParams` or `_PosreHarmParams` or `_PosreFBHWParams`
        """
        group = self._persistent if persistent else self._restraints
        table = group.get(table_name)
        if table is None:
            table = self._getTableClass(table_name)(table_name=table_name)
            group[table_name] = table
        return table

    def addTerm(self,
                table_name: str,
                atoms: List[AtomID],
                props: dict,
                persistent: bool = False) -> int:
        """
        Adds single restrain term.

        table_name is the desmond interaction table, stretch_harm,
        alchemical_improper_harm etc.

        An atom is specified by two numbers, ct number and atom number in ct.
        ct number starts from 0, that means the atom number is gid used by
        desmond backend. ct number 1 means full system. ct numbers greater
        than or equal to 2 mean component cts.

        props contain the actual force-field parameters, force constants,
        equilibrium angles and other parameters that specific to the term,
        e.g. schedule for alchemical terms

        :param table_name: name of the table, this is one of the term tables
            supported by desmond
        :param atoms: atom ids
        :param props: dictionary of parameter keyed by name of the parameter
            (`str`), including both table properties (e.g. schedule) and force
            field parameter properties.
        :param persistent: is this a persistent term
        :return: index of the term added
        """
        return self.getTable(table_name, persistent).addTerm(atoms, props)

    @property
    def has_persistent(self) -> bool:
        return len(self._persistent) > 0

    def toJson(self) -> str:
        """
        :return: json string to be loaded by msys
        """
        if not self.has_persistent:
            return json.dumps({RESTRAINT_KEY: self._restraints},
                              cls=_JsonEncoder)
        # merge persistent tables into a copy of the regular ones so that
        # `self` is left untouched
        merged = copy.deepcopy(self._restraints)
        for name, persistent_table in self._persistent.items():
            existing = merged.get(name)
            if existing:
                existing.incorporate(persistent_table)
            else:
                merged[name] = persistent_table
        return json.dumps(
            {
                RESTRAINT_KEY: merged,
                PERSISTENT_RESTRAINT_KEY: self._persistent
            },
            cls=_JsonEncoder)

    @property
    def has_positional(self) -> bool:
        """
        Set to True if the Restraint has any positional restraints.
        """
        for group in (self._restraints, self._persistent):
            for table in group.values():
                if table.is_positional:
                    return True
        return False
def _measure_ref(atoms):
    """
    Measure wrapper for distance, angle and torsion.

    :param atoms: atoms (or their coordinates) defining the term
    :type atoms: list/tuple of `structure.Atom` or of 3-float coordinates
    :return: distance, bond angle or dihedral angle for 2, 3 or 4 atoms
        respectively
    :rtype: float
    :raise ValueError: if the number of atoms is not 2, 3 or 4
    """
    n_atom = len(atoms)
    func = _MEASURE_FUNCTIONS.get(n_atom)
    if func is None:
        raise ValueError("Cannot measure on %d number of atoms" % n_atom)
    return func(*atoms)


def _generic_terms(struct, atom_selection, n_atom):
    """
    Enumerate terms from atom selection according to number of atoms n_atom
    2: bond, 3: angle, 4: dihedral.

    :param struct: structure for connection
    :type struct: `structure.Structure`
    :param atom_selection: selected atoms
    :type atom_selection: list of atom indices
    :param n_atom: number of atoms in a term
    :type n_atom: integer value of 2, 3 or 4
    :rtype: list of list
    :raise ValueError: if `n_atom` is not 2, 3 or 4
    """
    it = _TERM_ITERATORS.get(n_atom)
    if it is None:
        raise ValueError("Unknown term using %d atoms" % n_atom)
    return it(struct=struct, atoms=atom_selection)


def _select_atoms_single_term(cms_mol, atom_sel, n_atoms):
    """
    Select atoms for a single term.  Each of the first `n_atoms` selections
    in `atom_sel` must match exactly one atom; the first component CT that
    contains a match is used for that selection.
    Atoms returned are in the form of (CT number, `structure.Atom`)

    :param cms_mol: input system for selection
    :type cms_mol: `cms.Cms` object
    :param atom_sel: atom selection specification
    :type atom_sel: `sea.List` object
    :param n_atoms: number of atoms in term
    :type n_atoms: `int`
    :rtype: list of atom tuples, (`int`, `structure.Atom`)
    :raise ValueError: if fewer than `n_atoms` selections are given, a
        selection matches more than one atom within a CT, or a selection
        matches no atom at all
    """
    if len(atom_sel) < n_atoms:
        raise ValueError(
            "Expected %d atom selections for the term, got %d: %s" %
            (n_atoms, len(atom_sel), atom_sel))
    atoms_ret = []
    for i in range(n_atoms):
        for ct_idx, atoms in enumerate(cms_mol.select_atom_comp(
                atom_sel[i].val),
                                       start=2):
            if len(atoms) == 1:
                atoms_ret.append(
                    (ct_idx, cms_mol.comp_ct[ct_idx - 2].atom[atoms[0]]))
                break
            elif len(atoms) > 1:
                raise ValueError(
                    "More than 1 atom selected for the term, selection: %s" %
                    atom_sel[i].val)
    if len(atoms_ret) == n_atoms:
        return atoms_ret
    # some selection did not match any atom
    raise ValueError(
        "Selected atoms do not match required %d versus %d, selection: %s\n" %
        (len(atoms_ret), n_atoms, atom_sel))


def _select_atoms_for_generator(cms_mol, atom_sel):
    """
    Select atoms to be used for actual term generation.
    Atoms returned are in (CT number, `structure.Atom`) form.

    :param cms_mol: input system for selection
    :type cms_mol: `cms.Cms`
    :param atom_sel: atom selection specification, ASL
    :type atom_sel: `sea.Map`
    :rtype: list of atom tuples, (`int`, `structure.Atom`)
    """
    return [(ct_idx, cms_mol.comp_ct[ct_idx - 2].atom[a])
            for ct_idx, atoms in enumerate(
                cms_mol.select_atom_comp(atom_sel.val), start=2)
            for a in atoms]


def _find_force_constants(
        param_props: Iterable[str]) -> Tuple[List[str], List[str]]:
    """
    Find all force constant keys in parameter properties and return tuple of
    force constants and other properties.

    :param param_props: force field parameter names
    :return: (sorted force constant keys, sorted remaining keys)
    """
    fcs = []
    ref = []
    # single pass, so any iterable (including generators) is supported
    for p in param_props:
        if p.startswith(_FC):
            fcs.append(p)
        else:
            ref.append(p)
    return (sorted(fcs), sorted(ref))
class RestraintBuilder:
    """
    Builds restraint terms for a cms system from restraint specifications
    and stores the result (b64-encoded JSON) back on that system.
    """

    def __init__(self,
                 restraint_terms: sea.List,
                 existing: constants.EXISTING_RESTRAINT,
                 cms_sys: cms.Cms,
                 persistent: bool = False):
        """
        :param restraint_terms: all restraint terms to be added
        :param existing: One of `constants.EXISTING_RESTRAINT`, determines
            whether to `IGNORE` current restraints and replace them with
            `restraint_terms`, `RETAIN` them and update them with
            `restraint_terms`, or `IGNORE_POSRE` (drop only existing
            positional restraints).  Only the group (regular or persistent)
            being built is affected.
        :param cms_sys: cms object for molecules
        :param persistent: build "persistent" restraints
        """
        self._cms_sys = cms_sys
        self._all_restraints = restraint_terms
        self._persistent = persistent
        encoded = get_encoded_restraints(cms_sys)
        if encoded:
            self._restrain = Restraints(text=self._applyRestraintsDisposition(
                b64_decode(encoded), existing))
        else:
            self._restrain = Restraints()

        fep_cts = cms_sys.get_fep_cts()
        coords = (None, None)
        if all(fep_cts):
            try:
                from schrodinger.application.scisol.packages.core_hopping.int_fepio import \
                    get_reference_coordinates_for_two_molecules
                wt_ct, mut_ct = fep_cts
                coords = get_reference_coordinates_for_two_molecules(
                    wt_ct, mut_ct)
                # The cts are known to be set here; what matters is whether
                # any reference coordinates were actually returned.
                if any(c is not None for c in coords):
                    logger.debug(
                        "Reference coordinates found, using them for alchemical terms."
                    )
            except ImportError:
                logger.debug(
                    "Cannot import scisol function, use current coordinates.")
            except RuntimeError:
                # get_reference_coordinates_for_two_molecules may raise
                # RuntimeError; fall back to the current coordinates.
                pass
        fep_ref_coord = {
            ct: xyz for ct, xyz in zip(fep_cts, coords) if xyz is not None
        }
        # component ct index (starting at 2) -> coordinates used to measure
        # reference values of alchemical terms
        self._ref_coords = {
            ct_idx: fep_ref_coord[ct] if ct in fep_ref_coord else ct.getXYZ()
            for ct_idx, ct in enumerate(cms_sys.comp_ct, start=2)
        }

    def _applyRestraintsDisposition(
            self, text: str, existing: constants.EXISTING_RESTRAINT) -> str:
        """
        :param text: Serialized restraints (see `Restraints.toJson()`).
        :param existing: Restraints disposition (ignore/retain/ignore_posre).
        :return: Serialized restraints ready for `Restraints` constructor.
        """
        dct = json.loads(text)
        stage_restraints = dct.get(RESTRAINT_KEY, {})
        persistent = dct.get(PERSISTENT_RESTRAINT_KEY, {})
        # only the group this builder writes to is affected
        group = persistent if self._persistent else stage_restraints
        if existing == constants.EXISTING_RESTRAINT.IGNORE_POSRE:
            for k in _POSRE_HARM_OR_FBHW:
                group.pop(k, None)
        elif existing != constants.EXISTING_RESTRAINT.RETAIN:
            group.clear()
        return json.dumps({
            RESTRAINT_KEY: stage_restraints,
            PERSISTENT_RESTRAINT_KEY: persistent
        })

    def _addAllPosreTerms(self, spec: 'sea.Map'):
        """
        Add positional restraint terms as directed by the `spec`.

        The reference positions are either the current coordinates
        ('reset', the default), the positions already stored in the table
        ('retain'), or an explicit flat list of x, y, z values (3 per
        selected atom); missing explicit values fall back to the current
        coordinates.

        :param spec: restraint specs for posre_harm or posre_fbhw
        """
        fc = spec.force_constants.val
        ref = spec.ref.val if 'ref' in spec else 'reset'
        model = self._cms_sys
        asl = spec.atoms.val if isinstance(spec.atoms,
                                           sea.Atom) else spec.atoms[0].val
        atom_idxs = model.select_atom_comp(asl)
        table_name = spec.name.val
        table = self._restrain.getTable(table_name, self._persistent)
        if _FBHW in table_name:
            make_table = functools.partial(_PosreFBHWParams,
                                           sigma=spec.sigma.val)
        else:
            make_table = _PosreHarmParams
        for ct_idx, (ct, ct_atom_idxs) in enumerate(zip(model.comp_ct,
                                                        atom_idxs),
                                                    start=2):
            if len(ct_atom_idxs) == 0:
                # nothing selected in this ct
                continue
            xyz = ct.getXYZ(copy=False)[np.array(ct_atom_idxs, dtype=int) - 1]
            assert xyz.base is None  # owns the memory
            if ref not in ('reset', 'retain'):
                # explicit reference: 3 values (x, y, z) per selected atom
                num_values = 3 * len(ct_atom_idxs)
                flat_ref = np.asarray(ref).ravel()
                if flat_ref.size < num_values:
                    logger.warning(
                        "WARNING: restrain reference array is too short, "
                        "using existing positions for remaining atoms.")
                flat_ref = flat_ref[:num_values]
                xyz.flat[:flat_ref.size] = flat_ref
            to_merge = make_table(atom_ids=list(
                zip(itertools.repeat(ct_idx), ct_atom_idxs)),
                                  k=fc,
                                  ref=xyz)
            if ref == 'retain':
                to_merge.copyRefs(table)
            table.incorporate(to_merge)

    def _addOneTerm(self, table_name, atoms_selected, res_spec, fc_keys,
                    ref_keys, term_props):
        """
        Add a single term, if reference is not provided in the res_spec,
        measure from the coordinates. For alchemical terms, prestored
        reference coordinates are used other than the current coordinates to
        compute term reference values if possible.

        :param table_name: name of the term table
        :type table_name: `str`
        :param atoms_selected: atoms in the term, each atom is
            (ct_number, `structure.Atom`) pair
        :type atoms_selected: `tuple` of (`int`, `structure.Atom`)'s
        :param res_spec: parameters for the term
        :type res_spec: `Sea.Map` object
        :param fc_keys: force constant keys
        :type fc_keys: `list` of `str`
        :param ref_keys: keys of reference distance, angle and dihedral
        :type ref_keys: `list` of `str`
        :param term_props: other term properties that define the force field
            e.g. schedule for alchemical terms
        :type term_props: `list` of `str`
        """
        term_atoms = [
            AtomID(ct_idx, atom.index) for ct_idx, atom in atoms_selected
        ]
        params_dict = {
            k: v.val for k, v in zip(fc_keys, res_spec.force_constants)
        }
        for k in term_props:
            params_dict[k] = res_spec[k].val
        measured = None
        for k in ref_keys:
            if k in res_spec:
                params_dict[k] = res_spec[k].val
                continue
            if measured is None:
                # Look up the table in the group being built, so that no
                # stray empty table is created in the other group.
                table = self._restrain.getTable(table_name, self._persistent)
                if table.is_alchemical:
                    # alchemical potential, measure from reference coordinates
                    coords = tuple(self._ref_coords[ct_idx][atom.index - 1]
                                   for ct_idx, atom in atoms_selected)
                    measured = _measure_ref(coords)
                else:
                    measured = _measure_ref(
                        [atom for _, atom in atoms_selected])
            params_dict[k] = measured
        self._restrain.addTerm(table_name,
                               term_atoms,
                               params_dict,
                               persistent=self._persistent)

    def _addTermFromGenerator(self, table_name, n_atoms, res_spec, fc_keys,
                              ref_keys, term_props):
        """
        Generate terms from the atom selections.

        :param table_name: name of the term table
        :type table_name: `str`
        :param n_atoms: number of atoms in each term
        :type n_atoms: `int`
        :param res_spec: parameters to generate terms
        :type res_spec: `sea.Map` object
        :param fc_keys: force constant keys
        :type fc_keys: `list` of `str`
        :param ref_keys: keys of reference distance, angle and dihedral
        :type ref_keys: `list` of `str`
        :param term_props: other term properties that define the force field
            e.g. schedule for alchemical terms
        :type term_props: `list` of `str`
        :raise ValueError: for an unknown generator, or when the 'product'
            generator does not get exactly `n_atoms` selections
        """
        generator = res_spec.generator.val
        if generator == GeneratorType.PRODUCT.value:
            if len(res_spec.atoms) != n_atoms:
                raise ValueError(
                    "The '%s' generator for table %s needs %d atom "
                    "selections, got %d" %
                    (generator, table_name, n_atoms, len(res_spec.atoms)))
            atoms_selected = [
                _select_atoms_for_generator(self._cms_sys, sel)
                for sel in res_spec.atoms
            ]
            for prod_atoms in itertools.product(*atoms_selected):
                self._addOneTerm(table_name, prod_atoms, res_spec, fc_keys,
                                 ref_keys, term_props)
        elif generator == GeneratorType.CONNECTED.value:
            # enumerate all terms first, then add them
            atoms_selected = []
            for ct_idx, atoms in enumerate(self._cms_sys.select_atom_comp(
                    res_spec.atoms[0].val),
                                           start=2):
                # filter out empty selections
                if len(atoms) == 0:
                    continue
                current_ct = self._cms_sys.comp_ct[ct_idx - 2]
                for t in _generic_terms(current_ct, atoms, n_atoms):
                    atoms_selected.append([
                        (ct_idx, current_ct.atom[a]) for a in t
                    ])
            for term in atoms_selected:
                self._addOneTerm(table_name, term, res_spec, fc_keys,
                                 ref_keys, term_props)
        else:
            raise ValueError(
                "Unknown restraint term generator '%s' for table %s; "
                "expected one of: %s" %
                (generator, table_name, ', '.join(
                    g.value for g in GeneratorType)))

    def addRestraints(self):
        """
        Add all restraint terms to the cms object passed in the constructor.
        This should be the only function called to process all the restraints
        specified
        """
        for r in self._all_restraints:
            table_name = r.name.val
            table = self._restrain.getTable(table_name, self._persistent)
            if table.is_positional:
                self._addAllPosreTerms(r)
                continue
            param_props, term_props = get_table_schema(table_name)
            fcs, refs = _find_force_constants(param_props)
            n_atoms = get_natoms_in_term(table_name)
            if "generator" in r:
                self._addTermFromGenerator(table_name, n_atoms, r, fcs, refs,
                                           term_props)
            else:
                atoms_selected = _select_atoms_single_term(
                    self._cms_sys, r.atoms, n_atoms)
                self._addOneTerm(table_name, atoms_selected, r, fcs, refs,
                                 term_props)
        set_encoded_restraints(self._cms_sys, self.getEncoded())

    def getEncoded(self):
        """
        Reports restraints built as b64 encoded JSON string.

        :rtype: `str`
        """
        return b64_encode(self.getJson())

    def getJson(self):
        """
        :rtype: `str`
        """
        return self._restrain.toJson()

    def getString(self, skip_tables=None, **kwargs) -> str:
        """
        :param skip_tables: Skip the listed tables in the result string.
        :param kwargs: forwarded to `pprint.pformat`
        """
        skip_tables = skip_tables or {}
        dct = json.loads(self.getJson())
        return '\n'.join(
            f"'{k}':\n{pprint.pformat({k2: v2 for k2, v2 in v.items() if k2 not in skip_tables}, **kwargs)}"
            for k, v in dct.items())
def generate_conf_pose_restraints(cts,
                                  ct_numbers,
                                  enable_pose_restraint=False,
                                  pose_restraint_cfg=None,
                                  pose_restraint_terms=None,
                                  enable_conf_restraint=False,
                                  conf_restraint_cfg=None,
                                  conf_restraint_terms=None):
    """
    Generate pose and conf restraints according to cfg.

    :param cts: tuple of reference and mutant structures
    :type cts: Tuple(structure.Structure, structure.Structure)
    :param ct_numbers: tuple of reference and mutant ct numbers in the
        original Maestro file
    :type ct_numbers: Tuple(int, int)
    :param enable_pose_restraint: flag for ligand pose restraints
    :type enable_pose_restraint: bool
    :param pose_restraint_cfg: ligand pose specification
    :type pose_restraint_cfg: Dict
    :param pose_restraint_terms: dihedrals need to be restrained
    :type pose_restraint_terms: AlchemicalInteractions.pose_restraint
    :param enable_conf_restraint: flag for conformational restraints
    :type enable_conf_restraint: bool
    :param conf_restraint_cfg: conformation specification
    :type conf_restraint_cfg: Dict
    :param conf_restraint_terms: dihedrals need to be restrained
    :type conf_restraint_terms: AlchemicalInteractions.conf_restraint
    :rtype: Restraints
    """
    from schrodinger.application.desmond.struc import \
        get_atom_reference_coordinates

    _NAME = 'name'
    _SCHEDULE = 'schedule'
    _SOFT = 'soft'
    _ALPHA = 'alpha'
    # reference-value parameter prefix keyed by number of atoms in the term
    _REF_PREFIX = {2: 'r0', 4: 'phi0'}

    def _get_AB_keys(key: str) -> Tuple[str, str]:
        """Return the state-A and state-B variants of parameter `key`."""
        return key + 'A', key + 'B'

    _FCA, _FCB = _get_AB_keys(_FC)
    _SIGMA_A, _SIGMA_B = _get_AB_keys(_SIGMA)
    _ALPHA_A, _ALPHA_B = _get_AB_keys(_ALPHA)

    def _restraint_params(ct,
                          ct_num,
                          atoms,
                          param_spec,
                          is_wt=True,
                          restraint_type=None):
        """
        Build the atom ids and parameter dict of one alchemical term whose
        reference value is measured from the atoms' reference coordinates.
        """
        natoms = len(atoms)
        ref_key_a, ref_key_b = _get_AB_keys(_REF_PREFIX[natoms])
        ref = _MEASURE_FUNCTIONS[natoms](*[
            get_atom_reference_coordinates(ct.atom[atom]) for atom in atoms
        ])
        param_dict = {
            _FCA: 0.0,
            _FCB: 0.0,
            ref_key_a: ref,
            ref_key_b: ref,
            _SCHEDULE: param_spec[_SCHEDULE]
        }
        # The force constant is non-zero in one end state only: state B for
        # the reference (is_wt) structure, state A for the mutant.
        if is_wt:
            param_dict[_FCB] = param_spec[_FC]
        else:
            param_dict[_FCA] = param_spec[_FC]
        if param_spec[_NAME] == _FBHW:
            param_dict[_SIGMA_A] = param_spec[_SIGMA]
            param_dict[_SIGMA_B] = param_spec[_SIGMA]
        if restraint_type == constants.ConfRestraintType.CALPHA_RUNG or \
           param_spec[_NAME] == _SOFT:
            param_dict[_ALPHA_A] = param_spec[_ALPHA]
            param_dict[_ALPHA_B] = param_spec[_ALPHA]
        term_atoms = tuple(AtomID(ct_num, atom) for atom in atoms)
        return (term_atoms, param_dict)

    # (is_wt, state key, structure, ct number) for reference, then mutant
    _PACKED_ARGUMENTS = list(
        zip([True, False], constants.FEP_STATE_KEYS, cts, ct_numbers))
    _ALCHEMICAL_IMPROPER = 'alchemical_improper_'
    _ALCHEMICAL_SOFTSTRETCH = 'alchemical_softstretch_'
    all_restraints = Restraints()
    if enable_pose_restraint:
        table_name = _ALCHEMICAL_IMPROPER + pose_restraint_cfg[_NAME]
        for is_wt, key, ct, ct_num in _PACKED_ARGUMENTS:
            for dihe in pose_restraint_terms[key]:
                term_atoms, param_dict = _restraint_params(ct,
                                                           ct_num,
                                                           dihe,
                                                           pose_restraint_cfg,
                                                           is_wt=is_wt)
                all_restraints.addTerm(table_name, term_atoms, param_dict)
    if enable_conf_restraint:
        table_name_prefix = {
            constants.ConfRestraintType.BACKBONE: _ALCHEMICAL_IMPROPER,
            constants.ConfRestraintType.SIDECHAIN: _ALCHEMICAL_IMPROPER,
            constants.ConfRestraintType.CALPHA_RUNG: _ALCHEMICAL_SOFTSTRETCH
        }
        for r in constants.ConfRestraintType:
            conf_spec = conf_restraint_cfg[r]
            table_name = table_name_prefix[r] + conf_spec[_NAME]
            for is_wt, key, ct, ct_num in _PACKED_ARGUMENTS:
                for atoms in conf_restraint_terms[key][r]:
                    term_atoms, param_dict = _restraint_params(
                        ct,
                        ct_num,
                        atoms,
                        conf_spec,
                        is_wt=is_wt,
                        restraint_type=r)
                    all_restraints.addTerm(table_name, term_atoms, param_dict)
    return all_restraints
def has_positional_restraints(model: cms.Cms) -> bool:
    """
    Tell whether `model` carries positional restraints, either as legacy
    ffio restraints or among the encoded (new-style) restraints.

    :param model: cms model to inspect
    :return: True if any positional restraint is found
    """
    # legacy ffio_restraints defined on the cms model
    legacy_hit = any(
        set(r.keys()).intersection(
            [constants.RestrainTypes.POS, constants.RestrainTypes.POS_FBHW])
        for r in model.get_restrain())
    if legacy_hit:
        return True
    # new-style restraints stored as an encoded string
    encoded = get_encoded_restraints(model)
    if not encoded:
        return False
    return Restraints(text=b64_decode(encoded)).has_positional