"""
Restraint generation for cross-CT terms and all terms supported
by desmond backend, including alchemical terms.
Either a single term or a generator can be used.
For single term, each selection corresponds to a single atom.
Two kinds of generators are implemented now:
product: The Cartesian product of all selections is used to generate
the terms. Use case is to keep alchemical ions away from
places where they may get stuck.
connected: One selection is evaluated to generate terms for bond, angle
and torsion. Use case is the alchemical restraints
on protein conformations.
Reference distance, angle and torsion values are computed for generated
terms. For alchemical terms, reference coordinates saved previously will
be used for these calculations if available.
Copyright Schrodinger, LLC. All rights reserved.
"""
import copy
import dataclasses
import enum
import functools
import itertools
import json
import pprint
from typing import List
from typing import Tuple
import numpy as np
from schrodinger.application.desmond import cms
from schrodinger.application.desmond import constants
from schrodinger.application.desmond.packages import msys
from schrodinger.structutils import analyze
from schrodinger.structutils import measure
from schrodinger.utils import log
from schrodinger.utils import sea
from .utils import PERSISTENT_RESTRAINT_KEY
from .utils import RESTRAINT_KEY
from .utils import b64_decode
from .utils import b64_encode
from .utils import get_encoded_restraints
from .utils import set_encoded_restraints
logger = log.get_logger(name="restraint_builder")

# Key under which every table stores its per-term atom-ID lists.
_ATOMS_KEY = "atoms"

# Term enumerators from `analyze`, keyed by the number of atoms in a term.
_TERM_ITERATORS = {
    2: analyze.bond_iterator,
    3: analyze.angle_iterator,
    4: analyze.torsion_iterator,
}

# Geometry measurements from `measure`, keyed by the number of atoms.
_MEASURE_FUNCTIONS = {
    2: measure.measure_distance,
    3: measure.measure_bond_angle,
    4: measure.measure_dihedral_angle,
}

# Parameter names / prefixes.
_FC = 'fc'
_SIGMA = 'sigma'
_FBHW = 'fbhw'

# Positional-restraint table names.
_POSRE_HARM = 'posre_harm'
_POSRE_FBHW = 'posre_fbhw'
_POSRE_HARM_OR_FBHW = frozenset((_POSRE_HARM, _POSRE_FBHW))
class GeneratorType(enum.Enum):
    """
    Ways of generating restraint terms from atom selections.
    """
    # every combination of atoms drawn from the individual selections
    PRODUCT = 'product'
    # bond/angle/torsion tuples enumerated within a single selection
    CONNECTED = 'connected'
# Empty msys system; used only to look up desmond term-table schemas.
_dummy_msys = msys.CreateSystem()

# Public API of this module.
__all__ = [
    'AtomID', 'Restraints', 'RestraintBuilder',
    'generate_conf_pose_restraints', 'has_positional_restraints'
]
def decode_atom_ids(encoded_ids):
    """
    Lazily decode serialized atom IDs.

    :param encoded_ids: iterable of encoded atom IDs, each either a
        ``{'ct': ..., 'atom': ...}`` dict or a ``(ct, atom)`` sequence
    :return: iterator over the corresponding `AtomID` objects
    """
    return (AtomID.make(encoded) for encoded in encoded_ids)
def get_natoms_in_term(table_name: str) -> int:
    """
    Return the arity (number of atoms per term) of the desmond term
    table named `table_name`.

    :param table_name: name of the desmond term table
    :return: number of atoms in each term of that table
    """
    # `addTableFromSchema` creates the named table only once; later calls
    # with the same name return the already existing table.
    # c.f. http://opengrok/xref/desmond-gpu-src/other/msys/python/__init__.py#1277
    return _dummy_msys.addTableFromSchema(table_name).natoms
def get_table_schema(table_name: str):
    """
    Return the schema of the desmond term table named `table_name`.

    :param table_name: name of the desmond term table
    :return: (param_props, table_props), both frozensets of property
             names; the 'constrained' term property is left out
    :rtype: `tuple(frozenset(str), frozenset(str))`
    """
    # only creates the table if it does not exist yet (see
    # `get_natoms_in_term`)
    schema_table = _dummy_msys.addTableFromSchema(table_name)
    param_props = frozenset(schema_table.params.props)
    term_props = frozenset(name for name in schema_table.term_props
                           if name != 'constrained')
    return param_props, term_props
@dataclasses.dataclass
class AtomID:
    """
    Atom is specified by two numbers: ct index and atom index
    within this ct; ct indices starts from 0; ct == 0 indicates that
    the atom number is a gid used by desmond backend; ct == 1
    is the "full system"; ct >= 2 correspond to the component cts
    """
    ct: int  # ct index (see class docstring for its meaning)
    atom: int  # atom index within that ct

    @staticmethod
    def make(v):
        """
        Build an `AtomID` from its serialized form.

        :param v: a dict with ``ct`` and ``atom`` keys, or a sequence
            whose first two items are the ct index and the atom index
        :rtype: AtomID
        """
        if isinstance(v, dict):
            return AtomID(**v)
        ct_index, atom_index = v[0], v[1]
        return AtomID(ct=ct_index, atom=atom_index)
class _GenericParams(dict):
    """
    Column-oriented storage for the terms of a non-positional desmond
    table: maps every schema property name (plus the atoms key) to a list
    holding one entry per term.
    """

    def __init__(self, *, table_name: str, dct=None):
        """
        :param table_name: desmond term table name
        :type table_name: str
        :param dct: dictionary deserialized from JSON
        :type dct: dict
        """
        self._is_alchemical = table_name.startswith('alchemical_')
        param_props, term_props = get_table_schema(table_name)
        for column in sorted({_ATOMS_KEY} | param_props | term_props):
            self[column] = []
        for key, values in (dct or {}).items():
            if key == _ATOMS_KEY:
                self[key] += [list(decode_atom_ids(ids)) for ids in values]
            else:
                self[key] += values

    def addTerm(self, atoms: List[AtomID], props: dict) -> int:
        """
        Append one term. `props` holds the actual force-field parameters
        (force constants, equilibrium values, ...) and other term-specific
        values, e.g. the schedule for alchemical terms.

        :param atoms: list of atom ids
        :param props: parameter values keyed by parameter name (`str`),
            covering both table properties (e.g. schedule) and force
            field parameter properties
        :return: index of the term added
        """
        self[_ATOMS_KEY].append(atoms)
        for name in props:
            self[name].append(props[name])
        return len(self[_ATOMS_KEY]) - 1

    @property
    def is_positional(self):
        return False

    @property
    def is_alchemical(self):
        return self._is_alchemical

    def incorporate(self, other: '_GenericParams'):
        """
        Merge (a deep copy of) the terms of `other` into `self`.

        :raise NotImplementedError: if both tables already contain terms
        """
        assert self.keys() == other.keys()
        own_atoms, other_atoms = self[_ATOMS_KEY], other[_ATOMS_KEY]
        if own_atoms and other_atoms:
            raise NotImplementedError(
                "merging is supported for positional restraints only")
        if not other_atoms:
            return
        for column, values in self.items():
            values += copy.deepcopy(other[column])
class _PosreHarmParams:
    """
    Stores group of position restraints as a numpy structured array
    for faster creation (if from existing arrays) and merging.
    """
    # (ct index, atom index) pair identifying one atom, cf. `AtomID`
    AID_TYPE = [('ct', 'i8'), ('atom', 'i8')]
    # one record per restrained atom: its id, per-axis force constants
    # and reference position
    DTYPE = [("atom_id", AID_TYPE), ("k", "f8", (3,)), ("ref", "f8", (3,))]
    # field(s) used to recognize the same restraint in two tables
    INTERSECT_KEYS = "atom_id"

    def __init__(self, *, atom_ids=None, k=None, ref=None, dct=None, **kwargs):
        """
        :param atom_ids: sequence of (ct, atom) pairs, one per restrained atom
        :param k: force constants, broadcastable to (len(atom_ids), 3)
        :param ref: reference positions, broadcastable to (len(atom_ids), 3)
        :param dct: dictionary de-serialized from json; if truthy,
                    overrides other arguments
        :param kwargs: ignored; lets callers pass `table_name` uniformly
                       to every table class
        """
        if dct:
            atom_ids = self._transcode_atom_ids(dct[_ATOMS_KEY])
            ref = list(zip(dct['x0'], dct['y0'], dct['z0']))
            k = list(zip(dct['fcx'], dct['fcy'], dct['fcz']))
        num_entries = len(atom_ids or [])
        self.arr = np.empty(shape=num_entries, dtype=self.DTYPE)
        if num_entries:
            self.arr['atom_id'] = atom_ids
            self.arr['ref'] = ref
            self.arr['k'] = k

    @staticmethod
    def _transcode_atom_ids(list_of_encoded_aid_lists):
        """
        Helper function that takes list-of-lists of atom IDs encoded
        either as (ct, atom) pairs or {'ct': ct, 'atom': atom} dicts
        and returns list of (ct, atom) tuples (see c-tor). Only the first
        atom ID of each inner list is used.
        """
        list_of_tuples = []
        for encoded_aids in list_of_encoded_aid_lists:
            aid = next(decode_atom_ids(encoded_aids))
            list_of_tuples.append(dataclasses.astuple(aid))
        return list_of_tuples

    def _intersect_with(self, other) -> Tuple[np.ndarray, np.ndarray]:
        """
        Find the duplicate entries in self and other and return their indices,
        respectively
        """
        _, self_idx, other_idx = np.intersect1d(self.arr[self.INTERSECT_KEYS],
                                                other.arr[self.INTERSECT_KEYS],
                                                assume_unique=True,
                                                return_indices=True)
        return self_idx, other_idx

    def incorporate(self, other: '_PosreHarmParams'):
        """
        Merge `other` into `self`: for atoms present in both, keep the larger
        force constant per axis and take the reference position from
        `other`; atoms only in `other` are appended.
        """
        assert isinstance(other, _PosreHarmParams)
        self_idx, other_idx = self._intersect_with(other)
        other_k = other.arr['k'][other_idx]
        self.arr['k'][self_idx] = np.maximum(self.arr['k'][self_idx], other_k)
        self.arr['ref'][self_idx] = other.arr['ref'][other_idx]
        mask = np.ones(other.arr.size, dtype=bool)
        mask[other_idx] = False
        self.arr = np.concatenate([self.arr, other.arr[mask]])

    def copyRefs(self, other: '_PosreHarmParams'):
        """
        Copy reference positions from `other` for atoms present in both.
        """
        assert isinstance(other, _PosreHarmParams)
        self_idx, other_idx = self._intersect_with(other)
        self.arr['ref'][self_idx] = other.arr['ref'][other_idx]

    @property
    def atom_ids_as_tuples(self):
        # each entry is wrapped in a one-element list (the serialized form
        # consumed by `_transcode_atom_ids`)
        return [
            [(int(aid['ct']), int(aid['atom']))] for aid in self.arr['atom_id']
        ]

    @property
    def refs(self):
        # yields ('x0', [...]), ('y0', [...]), ('z0', [...])
        for a, v in zip('xyz', self.arr['ref'].transpose().tolist()):
            yield f'{a}0', v

    def asdict(self):
        outcome = {_ATOMS_KEY: self.atom_ids_as_tuples}
        outcome.update(self.refs)
        for a, v in zip('xyz', self.arr['k'].transpose().tolist()):
            outcome[f'{_FC}{a}'] = v
        return outcome

    def __len__(self):
        return len(self.arr)

    @property
    def is_positional(self):
        return True

    @property
    def is_alchemical(self):
        return False
class _PosreFBHWParams(_PosreHarmParams):
    """
    Position restraints for the FBHW table: like `_PosreHarmParams`, but
    with a single (scalar) force constant and a `sigma` value per atom
    (presumably the flat-bottom width -- confirm against desmond docs).
    """
    DTYPE = [("atom_id", _PosreHarmParams.AID_TYPE), ("k", "f8"),
             ("ref", "f8", (3,)), (_SIGMA, "f8")]
    # the same atom with a different sigma is treated as a distinct term
    INTERSECT_KEYS = ["atom_id", _SIGMA]

    def __init__(self,
                 *,
                 atom_ids=None,
                 k=None,
                 ref=None,
                 sigma=None,
                 dct=None,
                 **kwargs):
        """
        :param atom_ids: sequence of (ct, atom) pairs, one per restrained atom
        :param k: force constant(s), scalar or one per atom
        :param ref: reference positions, broadcastable to (len(atom_ids), 3)
        :param sigma: sigma value(s), scalar or one per atom
        :param dct: dictionary de-serialized from json; if truthy,
                    overrides other arguments
        :param kwargs: ignored; lets callers pass `table_name` uniformly
        """
        if dct:
            atom_ids = self._transcode_atom_ids(dct[_ATOMS_KEY])
            ref = list(zip(dct['x0'], dct['y0'], dct['z0']))
            k = dct[_FC]
            sigma = dct[_SIGMA]
        super().__init__(atom_ids=atom_ids, k=k, ref=ref)
        # The array comes from `np.empty`, so any provided sigma -- including
        # 0.0 or a numpy array, which a plain truthiness test would skip or
        # reject -- must be written explicitly. Broadcasts as needed.
        if sigma is not None:
            self.arr[_SIGMA] = sigma

    def asdict(self):
        outcome = {
            _ATOMS_KEY: self.atom_ids_as_tuples,
            _FC: self.arr['k'].tolist(),
            _SIGMA: self.arr[_SIGMA].tolist()
        }
        outcome.update(self.refs)
        return outcome
class _JsonEncoder(json.JSONEncoder):
    """
    JSON encoder that additionally serializes `AtomID` objects (as a
    ``(ct, atom)`` tuple) and positional-restraint tables (via `asdict`).
    """

    @functools.singledispatchmethod
    def default(self, obj):
        # unregistered types: let the base class handle them
        return super().default(obj)

    @default.register
    def _encode_atom_id(self, obj: AtomID):
        return dataclasses.astuple(obj)

    @default.register
    def _encode_posre_table(self, obj: _PosreHarmParams):
        return obj.asdict()
class Restraints:
    """
    Holds restraint terms parameters. Assumes that "persistent"
    (aka "permanent") tables support merging (only "posre_*" at the moment).
    """

    def __init__(self, *, text=None):
        """
        :param text: pre-existing serialized restraints (as json) to build upon
        :type text: str or NoneType
        """
        if text is None:
            data = {RESTRAINT_KEY: {}, PERSISTENT_RESTRAINT_KEY: {}}
        else:
            data = json.loads(text)
        self._restraints = self._instantiateTables(
            data.get(RESTRAINT_KEY, {}))
        self._persistent = self._instantiateTables(
            data.get(PERSISTENT_RESTRAINT_KEY, {}))

    @classmethod
    def _instantiateTables(cls, serialized):
        """
        Build table objects from their JSON-decoded form.

        :param serialized: mapping of table name to decoded table data
        :return: mapping of table name to table object
        """
        return {
            name: cls._getTableClass(name)(table_name=name, dct=table_data)
            for name, table_data in serialized.items()
        }

    @staticmethod
    def _getTableClass(table_name):
        """
        :return: the class that stores terms of the named table
        """
        if table_name == _POSRE_HARM:
            return _PosreHarmParams
        if table_name == _POSRE_FBHW:
            return _PosreFBHWParams
        return _GenericParams

    def getTable(self, table_name: str, persistent: bool = False) -> object:
        """
        Gets parameters table by name, creating an empty one on first use.

        :param table_name: name of the desmond term table
        :param persistent: persistent vs regular table
        :return: requested table
        :rtype: `_GenericParams` or `_PosreHarmParams` or `_PosreFBHWParams`
        """
        tables = self._persistent if persistent else self._restraints
        if table_name not in tables:
            table_class = self._getTableClass(table_name)
            tables[table_name] = table_class(table_name=table_name)
        return tables[table_name]

    def addTerm(self,
                table_name: str,
                atoms: List[AtomID],
                props: dict,
                persistent: bool = False) -> int:
        """
        Adds single restrain term.

        table_name is the desmond interaction table, stretch_harm,
        alchemical_improper_harm etc.
        An atom is specified by two numbers, ct number and atom number
        in ct. ct number starts from 0, that means the atom number is
        gid used by desmond backend. ct number 1 means full system.
        ct numbers greater than or equal to 2 mean component cts.
        props contain the actual force-field parameters, force constants,
        equilibrium angles and other parameters that specific to the term,
        e.g. schedule for alchemical terms

        :param table_name: name of the table, this is one of the term
                           tables supported by desmond
        :param atoms: atom ids
        :param props: dictionary of parameter keyed by name of the
                      parameter (`str`), including both table properties
                      (e.g. schedule) and force field parameter properties.
        :param persistent: is this a persistent term
        :return: index of the term added
        """
        return self.getTable(table_name, persistent).addTerm(atoms, props)

    @property
    def has_persistent(self) -> bool:
        return bool(self._persistent)

    def toJson(self) -> str:
        """
        Serialize the restraints. When persistent tables exist, the regular
        section holds (a copy of) the regular tables with the persistent
        ones merged in, and the persistent tables are also emitted on their
        own.

        :return: json string to be loaded by msys
        """
        if not self.has_persistent:
            return json.dumps({RESTRAINT_KEY: self._restraints},
                              cls=_JsonEncoder)
        merged = copy.deepcopy(self._restraints)
        for name, persistent_table in self._persistent.items():
            table = merged.get(name)
            if table:
                table.incorporate(persistent_table)
            else:
                table = persistent_table
            merged[name] = table
        return json.dumps(
            {
                RESTRAINT_KEY: merged,
                PERSISTENT_RESTRAINT_KEY: self._persistent
            },
            cls=_JsonEncoder)

    @property
    def has_positional(self) -> bool:
        """
        Set to True if the Restraint has any positional restraints.
        """
        all_tables = itertools.chain(self._restraints.values(),
                                     self._persistent.values())
        return any(table.is_positional for table in all_tables)
def _measure_ref(atoms):
    """
    Measure a distance (2 atoms), bond angle (3) or dihedral angle (4).

    :param atoms: list/tuple of atoms
    :type atoms: list of `structure.Atom` or list of 3 floats
    :rtype: float
    :raise ValueError: for any other number of atoms
    """
    measure_fn = _MEASURE_FUNCTIONS.get(len(atoms))
    if not measure_fn:
        raise ValueError("Cannot measure on %d number of atoms" % len(atoms))
    return measure_fn(*atoms)
def _generic_terms(struct, atom_selection, n_atom):
    """
    Enumerate bonded terms within `atom_selection`; the kind of term is
    chosen by `n_atom` (2: bond, 3: angle, 4: dihedral).

    :param struct: structure for connection
    :type struct: `structure.Structure`
    :param atom_selection: selected atoms
    :type atom_selection: list of atom indices
    :param n_atom: number of atoms in a term
    :type n_atom: integer value of 2, 3 or 4
    :rtype: list of list
    :raise ValueError: for any other `n_atom`
    """
    term_iterator = _TERM_ITERATORS.get(n_atom)
    if not term_iterator:
        raise ValueError("Unknown term using %d atoms" % n_atom)
    return term_iterator(struct=struct, atoms=atom_selection)
def _select_atoms_single_term(cms_mol, atom_sel, n_atoms):
    """
    Select the atoms of a single term, one atom per selection.

    For each of the first `n_atoms` selections, component cts are scanned
    in order and the first ct in which the selection matches exactly one
    atom supplies that atom. Atoms returned are in the form of
    (CT number, `structure.Atom`).

    :param cms_mol: input system for selection
    :type cms_mol: `cms.Cms` object
    :param atom_sel: atom selection specification
    :type atom_sel: `sea.List` object
    :param n_atoms: number of atoms in term
    :type n_atoms: `int`
    :rtype: list of atom tuples, (`int`, `structure.Atom`)
    :raise ValueError: if a ct scanned before a match contains more than
        one selected atom, or if some selection matches no atom
    """
    atoms_ret = []
    for i in range(n_atoms):
        for ct_idx, atoms in enumerate(cms_mol.select_atom_comp(
                atom_sel[i].val),
                                       start=2):
            if len(atoms) == 1:
                atoms_ret.append(
                    (ct_idx, cms_mol.comp_ct[ct_idx - 2].atom[atoms[0]]))
                break
            elif len(atoms) > 1:
                raise ValueError(
                    "More than 1 atom selected for the term, selection: %s" %
                    atom_sel[i].val)
    if len(atoms_ret) == n_atoms:
        return atoms_ret
    # report "required <expected> versus <found>"
    raise ValueError(
        "Selected atoms do not match required %d versus %d, selection: %s\n" %
        (n_atoms, len(atoms_ret), atom_sel))
def _select_atoms_for_generator(cms_mol, atom_sel):
    """
    Collect every atom matched by an ASL selection, across all component
    cts, as (CT number, `structure.Atom`) pairs for term generation.

    :param cms_mol: input system for selection
    :type cms_mol: `cms.Cms`
    :param atom_sel: atom selection specification, ASL
    :type atom_sel: `sea.Map`
    :rtype: list of atom tuples, (`int`, `structure.Atom`)
    """
    per_ct_indices = cms_mol.select_atom_comp(atom_sel.val)
    return [(ct_number, cms_mol.comp_ct[ct_number - 2].atom[atom_index])
            for ct_number, indices in enumerate(per_ct_indices, start=2)
            for atom_index in indices]
def _find_force_constants(
        param_props: List[str]) -> Tuple[List[str], List[str]]:
    """
    Split parameter property names into force-constant keys (names that
    start with `_FC`) and all other keys.

    :param param_props: force field parameter names (any iterable is
        consumed exactly once)
    :return: (sorted force-constant keys, sorted remaining keys)
    """
    fcs = []
    ref = []
    # single pass, so one-shot iterables are handled correctly
    for p in param_props:
        if p.startswith(_FC):
            fcs.append(p)
        else:
            ref.append(p)
    return (sorted(fcs), sorted(ref))
class RestraintBuilder:
    """
    Builds restraint terms from restraint specifications and stores them,
    b64-encoded, on a cms system.
    """

    def __init__(self,
                 restraint_terms: sea.List,
                 existing: constants.EXISTING_RESTRAINT,
                 cms_sys: cms.Cms,
                 persistent: bool = False):
        """
        :param restraint_terms: all restraint terms to be added
        :param existing: One of `constants.EXISTING_RESTRAINT`, determines
            whether to `IGNORE` current restraints and replace them with
            `restraint_terms` or `RETAIN` them and update them with
            `restraint_terms`.
        :param cms_sys: cms object for molecules
        :param persistent: build "persistent" restraints
        """
        self._cms_sys = cms_sys
        self._all_restraints = restraint_terms
        self._persistent = persistent
        encoded = get_encoded_restraints(cms_sys)
        if encoded:
            self._restrain = Restraints(text=self._applyRestraintsDisposition(
                b64_decode(encoded), existing))
        else:
            self._restrain = Restraints()
        fep_cts = cms_sys.get_fep_cts()
        coords = (None, None)
        if all(fep_cts):
            try:
                from schrodinger.application.scisol.packages.core_hopping.int_fepio import \
                    get_reference_coordinates_for_two_molecules
                wt_ct, mut_ct = fep_cts
                coords = get_reference_coordinates_for_two_molecules(
                    wt_ct, mut_ct)
                # report only when reference coordinates were returned
                if any(c is not None for c in coords):
                    logger.debug(
                        "Reference coordinates found, using them for alchemical terms."
                    )
            except ImportError:
                logger.debug(
                    "Cannot import scisol function, use current coordinates.")
            except RuntimeError as err:
                # get_reference_coordinates_for_two_molecules may raise
                # RuntimeError; fall back to the current coordinates
                logger.debug(
                    "Reference coordinates unavailable (%s), "
                    "use current coordinates.", err)
        fep_ref_coord = {k: v for k, v in zip(fep_cts, coords) if v is not None}
        self._ref_coords = {
            ct_idx: fep_ref_coord[ct] if ct in fep_ref_coord else ct.getXYZ()
            for ct_idx, ct in enumerate(cms_sys.comp_ct, start=2)
        }

    def _applyRestraintsDisposition(
            self, text: str, existing: constants.EXISTING_RESTRAINT) -> str:
        """
        :param text: Serialized restraints (see `Restraints.toJson()`).
        :param existing: Restraints disposition (ignore/retain/ignore_posre).
        :return: Serialized restraints ready for `Restraints` constructor.
        """
        retain_all = existing == constants.EXISTING_RESTRAINT.RETAIN
        ignore_posre = existing == constants.EXISTING_RESTRAINT.IGNORE_POSRE
        dct = json.loads(text)
        stage_restraints = dct.get(RESTRAINT_KEY, {})
        persistent = dct.get(PERSISTENT_RESTRAINT_KEY, {})
        # only the group this builder writes to is affected
        group = persistent if self._persistent else stage_restraints
        if retain_all or ignore_posre:
            if ignore_posre:
                for k in _POSRE_HARM_OR_FBHW:
                    group.pop(k, None)
        else:
            group.clear()
        return json.dumps({
            RESTRAINT_KEY: stage_restraints,
            PERSISTENT_RESTRAINT_KEY: persistent
        })

    def _addAllPosreTerms(self, spec: 'sea.Map'):
        """
        Add positional restraint terms as directed by the `spec`.

        :param spec: restraint specs for posre_harm or posre_fbhw. `ref` is
            'reset' (the default; current positions are used), 'retain'
            (reference positions of atoms already in the table are kept) or
            an explicit coordinate array, consumed in flat x, y, z order.
        """
        fc = spec.force_constants.val
        ref = spec.ref.val if 'ref' in spec else 'reset'
        explicit_ref = ref not in ('reset', 'retain')
        if explicit_ref:
            flat_ref = np.asarray(ref).ravel()
        model = self._cms_sys
        atom_idxs = model.select_atom_comp(spec.atoms.val if isinstance(
            spec.atoms, sea.Atom) else spec.atoms[0].val)
        table_name = spec.name.val
        table = self._restrain.getTable(table_name, self._persistent)
        make_table = functools.partial(
            _PosreFBHWParams,
            sigma=spec.sigma.val) if _FBHW in table_name else _PosreHarmParams
        for ct_idx, (ct, ct_atom_idxs) in enumerate(zip(model.comp_ct,
                                                        atom_idxs),
                                                    start=2):
            xyz = ct.getXYZ(copy=False)[np.array(ct_atom_idxs, dtype=int) - 1]
            assert xyz.base is None  # owns the memory
            if explicit_ref:
                # NOTE(review): every component ct is filled from the start
                # of the same reference array -- confirm this is intended
                # when the selection spans several cts.
                num_coords = 3 * len(ct_atom_idxs)
                # the reference is flat, so compare against 3 values/atom
                if flat_ref.size < num_coords:
                    logger.warning(
                        "WARNING: restrain reference array is too short, "
                        "using existing positions for remaining atoms.")
                num_copied = min(flat_ref.size, num_coords)
                xyz.flat[:num_copied] = flat_ref[:num_copied]
            to_merge = make_table(atom_ids=list(
                zip(itertools.repeat(ct_idx), ct_atom_idxs)),
                                  k=fc,
                                  ref=xyz)
            if ref == 'retain':
                to_merge.copyRefs(table)
            table.incorporate(to_merge)

    def _addOneTerm(self, table_name, atoms_selected, res_spec, fc_keys,
                    ref_keys, term_props):
        """
        Add a single term, if reference is not provided in the res_spec,
        measure from the coordinates. For alchemical terms, prestored
        reference coordinates are used instead of the current coordinates
        to compute term reference values if possible.

        :param table_name: name of the term table
        :type table_name: `str`
        :param atoms_selected: atoms in the term, each atom is
                               (ct_number, `structure.Atom`) pair
        :type atoms_selected: `tuple` of (`int`, `structure.Atom`)'s
        :param res_spec: parameters for the term
        :type res_spec: `Sea.Map` object
        :param fc_keys: force constant keys
        :type fc_keys: `list` of `str`
        :param ref_keys: keys of reference distance, angle and dihedral
        :type ref_keys: `list` of `str`
        :param term_props: other term properties that define the force field
                           e.g. schedule for alchemical terms
        :type term_props: `list` of `str`
        """
        term_atoms = [
            AtomID(ct_idx, atom.index) for ct_idx, atom in atoms_selected
        ]
        params_dict = {
            k: v.val for k, v in zip(fc_keys, res_spec.force_constants)
        }
        for k in term_props:
            params_dict[k] = res_spec[k].val
        # Use the same table group that `addTerm` below writes to, so no
        # stray empty table is created in the other group.
        is_alchemical = self._restrain.getTable(table_name,
                                                self._persistent).is_alchemical
        for k in ref_keys:
            if k in res_spec:
                params_dict[k] = res_spec[k].val
            elif is_alchemical:
                # alchemical potential, measure from reference coordinates
                coords = tuple(self._ref_coords[ct_idx][atom.index - 1]
                               for ct_idx, atom in atoms_selected)
                params_dict[k] = _measure_ref(coords)
            else:
                params_dict[k] = _measure_ref(
                    [atom for _, atom in atoms_selected])
        self._restrain.addTerm(table_name,
                               term_atoms,
                               params_dict,
                               persistent=self._persistent)

    def _addTermFromGenerator(self, table_name, n_atoms, res_spec, fc_keys,
                              ref_keys, term_props):
        """
        Generate terms from the atom selections.

        :param table_name: name of the term table
        :type table_name: `str`
        :param n_atoms: number of atoms in each term
        :type n_atoms: `int`
        :param res_spec: parameters to generate terms
        :type res_spec: `sea.Map` object
        :param fc_keys: force constant keys
        :type fc_keys: `list` of `str`
        :param ref_keys: keys of reference distance, angle and dihedral
        :type ref_keys: `list` of `str`
        :param term_props: other term properties that define the force field
                           e.g. schedule for alchemical terms
        :type term_props: `list` of `str`
        """
        generator = res_spec.generator.val
        if generator == GeneratorType.PRODUCT.value:
            assert n_atoms == len(res_spec.atoms)
            atoms_selected = [
                _select_atoms_for_generator(self._cms_sys, sel)
                for sel in res_spec.atoms
            ]
            for prod_atoms in itertools.product(*atoms_selected):
                self._addOneTerm(table_name, prod_atoms, res_spec, fc_keys,
                                 ref_keys, term_props)
        elif generator == GeneratorType.CONNECTED.value:
            atoms_selected = []
            for ct_idx, atoms in enumerate(self._cms_sys.select_atom_comp(
                    res_spec.atoms[0].val),
                                           start=2):
                # filter out empty selections
                if len(atoms) == 0:
                    continue
                current_ct = self._cms_sys.comp_ct[ct_idx - 2]
                for t in _generic_terms(current_ct, atoms, n_atoms):
                    atoms_selected.append([
                        (ct_idx, current_ct.atom[a]) for a in t
                    ])
            for prod_atoms in atoms_selected:
                self._addOneTerm(table_name, prod_atoms, res_spec, fc_keys,
                                 ref_keys, term_props)
        else:
            # previously silently ignored; make misconfiguration visible
            logger.warning(
                "Unknown restraint generator %r for table %s, "
                "no terms added.", generator, table_name)

    def addRestraints(self):
        """
        Add all restraint terms to the cms object passed in the constructor.
        This should be the only function called to process all the restraints specified
        """
        for r in self._all_restraints:
            table_name = r.name.val
            table = self._restrain.getTable(table_name, self._persistent)
            if not table.is_positional:
                param_props, term_props = get_table_schema(table_name)
                fcs, refs = _find_force_constants(param_props)
                n_atoms = get_natoms_in_term(table_name)
                if "generator" in r:
                    self._addTermFromGenerator(table_name, n_atoms, r, fcs,
                                               refs, term_props)
                else:
                    atoms_selected = _select_atoms_single_term(
                        self._cms_sys, r.atoms, n_atoms)
                    self._addOneTerm(table_name, atoms_selected, r, fcs, refs,
                                     term_props)
            else:
                self._addAllPosreTerms(r)
        set_encoded_restraints(self._cms_sys, self.getEncoded())

    def getEncoded(self):
        """
        Reports restraints built as b64 encoded JSON string.

        :rtype: `str`
        """
        return b64_encode(self.getJson())

    def getJson(self):
        """
        :return: restraints built, serialized as JSON
        :rtype: `str`
        """
        return self._restrain.toJson()

    def getString(self, skip_tables=None, **kwargs) -> str:
        """
        Human-readable dump of the restraints, one section per group.

        :param skip_tables: Skip the listed tables in the result string.
        :param kwargs: forwarded to `pprint.pformat`
        """
        skip_tables = skip_tables or {}
        dct = json.loads(self.getJson())
        return '\n'.join(
            f"'{k}':\n{pprint.pformat({k2: v2 for k2, v2 in v.items() if k2 not in skip_tables}, **kwargs)}"
            for k, v in dct.items())
def generate_conf_pose_restraints(cts,
                                  ct_numbers,
                                  enable_pose_restraint=False,
                                  pose_restraint_cfg=None,
                                  pose_restraint_terms=None,
                                  enable_conf_restraint=False,
                                  conf_restraint_cfg=None,
                                  conf_restraint_terms=None):
    """
    Generate pose and conf restraints according to cfg.

    Every generated term is alchemical: its force constant is 0.0 at one
    end and the configured ``fc`` at the other. For the reference (first)
    structure the force constant goes into ``fcB``; for the mutant (second)
    structure into ``fcA``. Reference values are measured from the atoms'
    reference coordinates.

    :param cts: tuple of reference and mutant structures
    :type cts: Tuple(structure.Structure, structure.Structure)
    :param ct_numbers: tuple of reference and mutant ct numbers
                       in the original Maestro file
    :type ct_numbers: Tuple(int, int)
    :param enable_pose_restraint: flag for ligand pose restraints
    :type enable_pose_restraint: bool
    :param pose_restraint_cfg: ligand pose specification
    :type pose_restraint_cfg: Dict
    :param pose_restraint_terms: dihedrals need to be restrained
    :type pose_restraint_terms: AlchemicalInteractions.pose_restraint
    :param enable_conf_restraint: flag for conformational restraints
    :type enable_conf_restraint: bool
    :param conf_restraint_cfg: conformation specification
    :type conf_restraint_cfg: Dict
    :param conf_restraint_terms: atom tuples to be restrained, per
                                 `constants.ConfRestraintType`
    :type conf_restraint_terms: AlchemicalInteractions.conf_restraint
    :rtype: Restraints
    """
    from schrodinger.application.desmond.struc import \
        get_atom_reference_coordinates
    _NAME = 'name'

    def _restraint_params(ct,
                          ct_num,
                          atoms,
                          param_spec,
                          is_wt=True,
                          restraint_type=None):
        """
        :return: (term atom ids, parameter dict) for one term
        """

        def _get_AB_keys(key: str) -> Tuple[str, str]:
            return key + 'A', key + 'B'

        natoms = len(atoms)
        _STRETCH_REF_KEY = 'r0'
        _DIHEDRAL_REF_KEY = 'phi0'
        # only 2-atom (stretch) and 4-atom (dihedral) terms are supported
        _REF_PREFIX = {2: _STRETCH_REF_KEY, 4: _DIHEDRAL_REF_KEY}
        _FCA, _FCB = _get_AB_keys(_FC)
        _SIGMA_A, _SIGMA_B = _get_AB_keys(_SIGMA)
        _SCHEDULE = 'schedule'
        _SOFT = 'soft'
        _ALPHA = 'alpha'
        _ALPHA_A, _ALPHA_B = _get_AB_keys(_ALPHA)
        _REF_KEY_A, _REF_KEY_B = _get_AB_keys(_REF_PREFIX[natoms])
        ref = _MEASURE_FUNCTIONS[natoms](*[
            get_atom_reference_coordinates(ct.atom[atom]) for atom in atoms
        ])
        param_dict = {
            _FCA: 0.0,
            _FCB: 0.0,
            _REF_KEY_A: ref,
            _REF_KEY_B: ref,
            _SCHEDULE: param_spec[_SCHEDULE]
        }
        if is_wt:
            param_dict[_FCB] = param_spec[_FC]
        else:
            param_dict[_FCA] = param_spec[_FC]
        if param_spec[_NAME] == _FBHW:
            param_dict[_SIGMA_A] = param_spec[_SIGMA]
            param_dict[_SIGMA_B] = param_spec[_SIGMA]
        if restraint_type == constants.ConfRestraintType.CALPHA_RUNG or \
                param_spec[_NAME] == _SOFT:
            param_dict[_ALPHA_A] = param_spec[_ALPHA]
            param_dict[_ALPHA_B] = param_spec[_ALPHA]
        term_atoms = tuple(AtomID(ct_num, atom) for atom in atoms)
        return (term_atoms, param_dict)

    # (is reference structure, FEP state key, structure, ct number)
    _PACKED_ARGUMENTS = list(
        zip([True, False], constants.FEP_STATE_KEYS, cts, ct_numbers))
    _ALCHEMICAL_IMPROPER = 'alchemical_improper_'
    _ALCHEMICAL_SOFTSTRETCH = 'alchemical_softstretch_'
    all_restraints = Restraints()
    if enable_pose_restraint:
        table_name = _ALCHEMICAL_IMPROPER + pose_restraint_cfg[_NAME]
        for is_wt, key, ct, ct_num in _PACKED_ARGUMENTS:
            for dihe in pose_restraint_terms[key]:
                term_atoms, param_dict = _restraint_params(ct,
                                                           ct_num,
                                                           dihe,
                                                           pose_restraint_cfg,
                                                           is_wt=is_wt)
                all_restraints.addTerm(table_name, term_atoms, param_dict)
    if enable_conf_restraint:
        table_name_prefix = {
            constants.ConfRestraintType.BACKBONE: _ALCHEMICAL_IMPROPER,
            constants.ConfRestraintType.SIDECHAIN: _ALCHEMICAL_IMPROPER,
            constants.ConfRestraintType.CALPHA_RUNG: _ALCHEMICAL_SOFTSTRETCH
        }
        for r in constants.ConfRestraintType:
            conf_spec = conf_restraint_cfg[r]
            table_name = table_name_prefix[r] + conf_spec[_NAME]
            for is_wt, key, ct, ct_num in _PACKED_ARGUMENTS:
                for atoms in conf_restraint_terms[key][r]:
                    term_atoms, param_dict = _restraint_params(ct,
                                                               ct_num,
                                                               atoms,
                                                               conf_spec,
                                                               is_wt=is_wt,
                                                               restraint_type=r)
                    all_restraints.addTerm(table_name, term_atoms, param_dict)
    return all_restraints
def has_positional_restraints(model: cms.Cms) -> bool:
    """
    Report whether `model` carries positional restraints, looking first at
    its ffio restraints and then at the encoded (new-style) restraints.
    """
    positional_types = (constants.RestrainTypes.POS,
                        constants.RestrainTypes.POS_FBHW)
    # ffio restraints defined on the cms model
    if any(not set(restraint.keys()).isdisjoint(positional_types)
           for restraint in model.get_restrain()):
        return True
    # new-style encoded restraints
    encoded = get_encoded_restraints(model)
    if not encoded:
        return False
    return Restraints(text=b64_decode(encoded)).has_positional