Source code for schrodinger.application.desmond.fep_edge_data_classifier

from enum import Enum
from itertools import chain
from typing import Tuple

import numpy as np


class Rating(Enum):
    """
    Quality rating assigned to a piece of FEP edge data.

    `rate` returns one of these members; `NA` is what it falls back to when
    the requested classifier cannot be run on the data.
    """
    NA = 0  # no rating could be determined
    GOOD = 1
    FAIR = 2
    BAD = 3
class _Const:
    """Thresholds used by the classifiers in this module."""
    GOOD_RMSD = 2.0  # Angstrom
    FAIR_RMSD = GOOD_RMSD * 2  # Angstrom
    DG_SLOPE = 0.30  # kcal/mol/ns, allowed change rate in dG
    DG_CONVERGENCE_TIME_SPAN = 1  # ns, time span for convergence check
    # Dimensionless: upper bound on the mean relative standard deviation of
    # the replica history distributions for a GOOD REST exchange rating.
    REST_EX_CUTOFF = 0.15


# classifier names
CONVERGENCE = 'convergence'
LIGAND_RMSD = 'ligand RMSD'
REST_EXCHANGE = 'REST exchange'
def rate(name: str, data) -> Rating:
    """
    Return rating for the FEP edge data with the given `name`.

    The format of `data` varies depending on the case; it is passed through
    unchanged to the classifier registered under `name`.

    :param name: Name of the classifier. The allowed names are the
        module-level constants `CONVERGENCE`, `LIGAND_RMSD` and
        `REST_EXCHANGE`.
    :param data: Input for the selected classifier.
    :return: The classifier's rating, or `Rating.NA` if the classifier is
        unknown or raises while processing `data`.
    """
    try:
        return _CLASSIFIERS[name](data)
    except Exception as e:
        # Deliberately best-effort: an unknown name or malformed data must
        # not abort the caller, so the failure is reported and NA returned.
        # FIXME: is this enough?
        print(f"ERROR: Failed to classify {name} data for FEP edge: {e}.")
    return Rating.NA
def _convergence(data: Tuple[float, float, Tuple]) -> Rating:
    """
    Rate how well dG has converged over the final part of the simulation.

    The dG values are treated as evenly spaced samples between the start and
    end times (both inclusive).

    :param data: simulation times and dG values in the complex leg, i.e.,
        [start_time, end_time, dG_values]
    :return: `Rating.BAD` if the dG range in the checked window exceeds
        DG_SLOPE times the window length, `Rating.GOOD` if additionally
        every consecutive dG change is below DG_SLOPE*dt, otherwise
        `Rating.FAIR`.
    :raise ValueError: if fewer than two dG values are given.
    """
    start_time, end_time, dg_values = data[0], data[1], data[2]
    num_points = len(dg_values)
    # Explicit check rather than `assert`, which is stripped under `-O`.
    if num_points <= 1:
        raise ValueError(
            f"Cannot determine convergence with {num_points} data points")

    ts, dt = np.linspace(start_time, end_time, num_points, retstep=True)
    indices = ts >= end_time - _Const.DG_CONVERGENCE_TIME_SPAN
    indices[-2] = True  # Two points are needed no matter what
    check_ts = ts[indices]
    check_dgs = np.asarray(dg_values)[indices]

    # For the last DG_CONVERGENCE_TIME_SPAN ns, check if maximal dG change is
    # too big.
    allowed_change = _Const.DG_SLOPE * (check_ts[-1] - check_ts[0])
    if allowed_change < max(check_dgs) - min(check_dgs):
        return Rating.BAD

    # For the last DG_CONVERGENCE_TIME_SPAN ns, check if any two consecutive dG
    # values are close to each other within DG_SLOPE*dt.
    if np.all(np.abs(np.ediff1d(check_dgs)) < _Const.DG_SLOPE * dt):
        return Rating.GOOD
    return Rating.FAIR


def _RESTExchange(data: Tuple[Tuple[int, ...], ...]) -> Rating:
    """
    Rate REST replica exchange from the replica history distributions.

    Each distribution is normalized by its own mean; the rating is GOOD when
    the mean of the resulting standard deviations is below REST_EX_CUTOFF.
    This classifier never returns `Rating.BAD`.

    :param data: replica history distributions
    """
    # adapted from Dan Sindhikara's script
    stddevs = []
    for hist in data:
        hist = np.asarray(hist)
        stddevs.append(np.std(hist / hist.mean()))
    if np.mean(stddevs) < _Const.REST_EX_CUTOFF:
        return Rating.GOOD
    return Rating.FAIR


def _ligandRMSD(data: Tuple[Tuple[float, ...], ...]) -> Rating:
    """
    Rate the ligand RMSD values against the GOOD/FAIR thresholds.

    :param data: ligand RMSDs in the complex leg, i.e.,
        [lambda0_rmsd, lambda1_rmsd]
    :return: `Rating.BAD` if any value exceeds FAIR_RMSD, `Rating.FAIR` if
        any value exceeds GOOD_RMSD, otherwise `Rating.GOOD` (including when
        there are no values at all).
    """
    # Single pass so that one-shot iterables are rated correctly; iterating
    # twice would see an exhausted iterator on the second pass.
    rating = Rating.GOOD
    for rmsd in chain.from_iterable(data):
        if rmsd > _Const.FAIR_RMSD:
            return Rating.BAD
        if rmsd > _Const.GOOD_RMSD:
            rating = Rating.FAIR
    return rating


# Maps each classifier name to the function that produces its rating.
_CLASSIFIERS = {
    LIGAND_RMSD: _ligandRMSD,
    REST_EXCHANGE: _RESTExchange,
    CONVERGENCE: _convergence
}