Source code for schrodinger.application.desmond.fepana

"""
Tools for various FEP-related analyses.

Copyright Schrodinger, LLC. All rights reserved.

"""

# Contributors: Yujie Wu, Dan Sindhikara

import copy
import math
import os
import re
import sys
from past.utils import old_div
from pathlib import Path
from typing import Dict
from typing import List
from typing import Union

import numpy

from schrodinger.application.desmond import bennett
from schrodinger.application.desmond import cms
from schrodinger.application.desmond import config
from schrodinger.application.desmond import measurement
from schrodinger.application.desmond import util
from schrodinger.application.desmond.constants import BOLTZMANN


# FIXME: Replace this class with Joe's version.
class Restraint:

    def __init__(self, atom, ref, k):
        self.atom = atom
        self.ref = ref
        self.k = k

def _float_range(start, stop, step, closed=True):
    """
    Return evenly spaced float values from start to stop.

    :param start: First value in the range
    :param stop: Last value in the range if `closed`, first value outside of
        the range if not `closed`.
    :param step: The size of the step between elements of the result range.
    :param closed: Should `stop` be included in the range?
    """
    EPSILON = 1e-9
    signed_epsilon = -EPSILON if step < 0 else EPSILON
    if closed:
        points = numpy.arange(start, stop + signed_epsilon, step)
        if points.size > 1 and abs(points[-1] - stop) < EPSILON:
            points[-1] = stop
        return points
    else:
        return numpy.arange(start, stop - signed_epsilon, step)

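# Illustrative usage sketch (not part of the original module; values are
# hypothetical).  The helper behaves like numpy.arange, optionally closing
# the range on `stop`:
#
#     _float_range(0.0, 1.0, 0.25)                # [0.0, 0.25, 0.5, 0.75, 1.0]
#     _float_range(0.0, 1.0, 0.25, closed=False)  # [0.0, 0.25, 0.5, 0.75]
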
def get_energy_table(fname, term_list):
    """
    columns: (0, 0) (0, 1) (1, 1) ... "total"
    rows   : frame1, frame2, frame3, ...
    Separate table for each term.  table[row][col]
    """
    with open(fname, "r") as fh:
        s = fh.readlines()
    table = {term: [] for term in term_list}
    col_meaning = ["time"]
    TITLE_PATTERN = re.compile(r"\([^\(\)]+\)")
    for line in s:
        word = TITLE_PATTERN.findall(line)
        if (word != [] and word[0].lower() == "(pair)"):
            for w in word[2:]:
                if (w.lower() == "total"):
                    col_meaning.append("total")
                else:
                    col_meaning.append(
                        tuple([int(e) for e in w[1:-1].split(",")]))
            break
    for line in s:
        word = line.split()
        for term in table:
            if (word != [] and word[0] == term):
                value = [
                    float(word[1][1:-1]),
                ]
                value.extend([float(e) for e in word[2:]])
                table[term].append(value)
                break
    return (
        col_meaning,
        table,
    )

def get_global_quantity(fname, quantity_list):
    """
    Parse lines that start with "time" (formatted as space-separated
    name=value pairs) and collect the values of the requested quantities.
    Returns a dict mapping each quantity name to a list of values.
    """
    with open(fname, "r") as fh:
        s = fh.readlines()
    table = {quantity: [] for quantity in quantity_list}
    for line in s:
        if (line[:4] == "time"):
            quantity = line.split()
            for e in quantity:
                q, v = e.split("=")
                if (q in table):
                    table[q].append(float(v))
    return table

def get_mean(ene, index=-1, data_structure="table"):
    """
    Returns (mean, std_error, std_dev,).
    """
    ene_average = 0.0
    ene_stddev = 0.0
    num_data = 0
    if (data_structure == "table"):
        for row in ene:
            ene_average += row[index]
            num_data += 1
    elif (data_structure == "array"):
        for e in ene:
            ene_average += e
            num_data += 1
    else:
        raise ValueError("Unknown data structure: %s" % data_structure)
    if (num_data == 0):
        num_data = 1
    ene_average /= num_data
    if (data_structure == "table"):
        for row in ene:
            ene_stddev += (row[index] - ene_average) * (row[index] -
                                                        ene_average)
    elif (data_structure == "array"):
        for e in ene:
            ene_stddev += (e - ene_average) * (e - ene_average)
    if (num_data == 1):
        ene_stddev = float("inf")
    else:
        ene_stddev /= num_data - 1
        ene_stddev = math.sqrt(ene_stddev)
    return ene_average, old_div(ene_stddev, math.sqrt(num_data)), ene_stddev

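# Illustrative usage sketch (hypothetical data, not part of the original
# module): for a plain list of values use data_structure="array".
#
#     mean, std_err, std_dev = get_mean([1.0, 2.0, 3.0], data_structure="array")
#     # mean == 2.0, std_dev == 1.0, std_err == 1.0 / math.sqrt(3)
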
def parse_eneseq(eneseq_fname):
    """
    Returns a 1D structured array with names equal to the column names

    :param eneseq_fname: eneseq file name.
    """
    with open(eneseq_fname, "r") as fh:
        for line in fh.readlines():
            line = line.strip()
            if line and line[0] == "#":
                split = line.split()
                if len(split) > 1 and split[1] == '0:time':
                    # col_id:col_name (units) ...
                    header_labels = line.split()[1::2]
                    header_names = [
                        label.split(":")[1] for label in header_labels
                    ]
                    break
        else:
            raise ValueError("Column headers not found")
    return numpy.atleast_1d(numpy.genfromtxt(eneseq_fname, names=header_names))

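# Illustrative sketch of the eneseq header layout this parser expects
# (hypothetical excerpt and file name, not from an actual run):
#
#     # 0:time (ps)  1:E (kcal/mol)  2:E_p (kcal/mol)
#     0.0  -1234.5  -2345.6
#
#     arr = parse_eneseq("example.ene")
#     arr["time"], arr["E"]     # columns are accessed by name
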
def init_bennett(data: Union[str, numpy.ndarray],
                 n_win=12,
                 temperature=300.0,
                 begin_time=100.0,
                 end_time=-1.0,
                 random_seed=2111839,
                 result_file=None,
                 nresamples=0,
                 file_pattern='gibbs.%d.dE'):
    """
    :param data: Either a directory or a numpy array. As a directory, it must
        contain the dE files, which are named after the pattern specified by
        the `file_pattern` argument. As a numpy array, it is the dE data read
        from the dE files. The data is an MxNx3 array, where M is the number
        of lambda windows, N the number of time points, and the 3 are
        (time, forward energy, reverse energy).
    """
    assert isinstance(data, (str, Path, numpy.ndarray)), \
        "Must be a numpy array or str/Path"
    if are_times_insane(begin_time, end_time):
        print("Warning: BAR calculation initialized with unreasonable "
              f"begin and end times: {begin_time}, {end_time}")
    bar = bennett.CalcBAR(begin_time=begin_time,
                          end_time=end_time,
                          temperature=temperature,
                          seed=random_seed,
                          nresamples=nresamples)
    if isinstance(data, (str, Path)):
        if result_file is not None:
            result_file = os.path.join(data, os.path.basename(result_file))
        fns = [os.path.join(data, file_pattern % i) for i in range(0, n_win)]
        bar.load_data(fns)
    else:
        # data is a numpy array
        bar.dE = data
        bar.filter_data(begin_time, end_time)
    bar.set_output(result_file, None, None)
    return bar

def run_bennett(bar, begin_time=100.0, end_time=-1.0, nresamples=None):
    """
    Run the BAR calculation held by `bar` over the given time range.

    Returns a tuple: (total dG, error message, per-window dF list, raw
    results), where the dG and dF values are `measurement.Measurement`
    objects.
    """
    if nresamples is not None:
        bar.set_nresamples(nresamples)
        bar.set_seed(bar.seed)
    bar.filter_data(begin_time, end_time)
    try:
        results = bar.analyze_data()
        bar.write_results(results)
        dF = [
            measurement.Measurement(a[0], max(a[1], a[2]))
            for a in results[:-1]
        ]
        result = measurement.Measurement(results[-1][0],
                                         max(results[-1][1], results[-1][2]))
        return result, bar.err, dF, results
    except Exception as e:
        # Reduce begin_time gradually until there are enough statistics for
        # the BAR calculation to succeed.
        if begin_time > 0:
            begin_time -= 50.0
            if begin_time < 0:
                begin_time = 0.0
            return run_bennett(bar, begin_time, end_time, nresamples)
        return (None, bar.err + '\n' + repr(e), [], [])

def are_times_insane(begin_time, end_time):
    """
    Return True if the given begin and end times are unreasonable.  An end
    time of -1.0 means "to the end of the data" and is considered sane.
    """
    return ((end_time <= begin_time) and
            end_time != -1.0) or (end_time == 0) or (begin_time < 0)

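# Illustrative sketch (hypothetical values):
#
#     are_times_insane(100.0, -1.0)   # False ("to the end of the data")
#     are_times_insane(100.0, 50.0)   # True  (end before begin)
#     are_times_insane(-5.0, 200.0)   # True  (negative begin time)
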
def get_delta_time(begin_time, end_time, delta_time, window=0):
    if window >= end_time - begin_time:
        # we can only use a single, possibly truncated window
        return 0
    if isinstance(delta_time, str):
        tokens = delta_time.split(":")
        if tokens[1] == "points":
            num_point = int(tokens[0])
            if window:
                num_point -= 1
            span = end_time - begin_time - window
            delta_time = span / num_point
        else:
            raise ValueError("Wrong syntax: %s" % delta_time)
    return delta_time

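# Illustrative sketch (hypothetical values): `delta_time` may be a plain
# interval in ps or an "<N>:points" string requesting N evenly spaced points.
#
#     get_delta_time(100.0, 1100.0, 30.0)          # 30.0
#     get_delta_time(100.0, 1100.0, "10:points")   # (1100 - 100) / 10 == 100.0
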
def calc_free_energy_time_function(dir,
                                   last_time,
                                   n_win,
                                   temperature=300.0,
                                   begin_time=100.0,
                                   end_time=-1.0,
                                   delta_time=30.0,
                                   random_seed=2111839):
    """
    Calculates the free energy as a function of time.
    """
    try:
        end_time, dt = cleanup_time(begin_time, end_time, last_time,
                                    delta_time)
    except TimeSanityException:
        return [], []
    stop_times = _float_range(begin_time, end_time, dt, closed=False) + dt
    start_times = (begin_time,) * len(stop_times)
    time_ranges = list(zip(start_times, stop_times, stop_times))
    return calc_time_curve(dir, n_win, temperature, begin_time, end_time,
                           random_seed, time_ranges)

def calc_free_energy_rtime_function(dir,
                                    last_time,
                                    n_win,
                                    temperature=300.0,
                                    begin_time=100.0,
                                    end_time=-1.0,
                                    delta_time=30.0,
                                    random_seed=2111839):
    """
    Calculates the free energy as a function of reversed time.
    """
    try:
        end_time, dt = cleanup_time(begin_time, end_time, last_time,
                                    delta_time)
    except TimeSanityException:
        return [], []
    label_times = _float_range(begin_time, end_time, dt, closed=False)
    stop_times = (last_time,) * len(label_times)
    start_times = last_time - label_times
    time_ranges = list(zip(start_times, stop_times, label_times))
    return calc_time_curve(dir, n_win, temperature, begin_time, end_time,
                           random_seed, time_ranges)

def calc_free_energy_stime_function(dir,
                                    last_time,
                                    n_win,
                                    temperature=300.0,
                                    begin_time=100.0,
                                    end_time=-1.0,
                                    delta_time=30.0,
                                    window=500.0,
                                    random_seed=2111839):
    """
    Calculates the free energy as a function of time with a sliding window.
    """
    try:
        end_time, dt = cleanup_time(begin_time, end_time, last_time,
                                    delta_time, window)
    except TimeSanityException:
        return [], []
    # cleanup_time returns dt == 0 if the window is larger than
    # end_time - begin_time; in that case we use a single, possibly
    # truncated window.
    if dt == 0:
        start_times, stop_times = (begin_time,), (end_time,)
    else:
        stop_times = _float_range(begin_time + window, end_time, dt)
        # assert stop_times.size, "expected _float_range to return at least " \
        #     "one point"
        start_times = stop_times - window
    time_ranges = list(zip(start_times, stop_times, start_times))
    return calc_time_curve(dir, n_win, temperature, begin_time, end_time,
                           random_seed, time_ranges)

# backward compatibility
calc_freeenergy_time_function = calc_free_energy_time_function
calc_freeenergy_rtime_function = calc_free_energy_rtime_function
calc_freeenergy_stime_function = calc_free_energy_stime_function

class TimeSanityException(Exception):
    pass

def cleanup_time(begin_time, end_time, last_time, delta_time, window=0):
    if (end_time < 0 or end_time > last_time):
        end_time = last_time
    if are_times_insane(begin_time, end_time):
        raise TimeSanityException
    dt = get_delta_time(begin_time, end_time, delta_time, window)
    return end_time, dt

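# Illustrative sketch (hypothetical values): an end time of -1.0 is replaced
# by the last time of the trajectory before the sanity check.
#
#     cleanup_time(100.0, -1.0, 1200.0, 30.0)   # (1200.0, 30.0)
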
def calc_time_curve(dir, n_win, temperature, begin_time, end_time, random_seed,
                    time_ranges):
    """
    Calculates the free energy as a function of time_ranges
    """
    bar = init_bennett(dir, n_win, temperature, begin_time, end_time,
                       random_seed)
    data = []
    last_frame = len(time_ranges) - 1
    for i, (start_time, stop_time, time) in enumerate(time_ranges):
        # Include the uncertainty every 10th frame, and for the last frame
        nresamples = 100 if i % 10 == 0 or i == last_frame else 0
        result, err, dF, results = run_bennett(bar, start_time, stop_time,
                                               nresamples)
        if result is not None:
            data.append((
                time,
                result,
                dF,
            ))
    bar.close_output()
    if len(results) == 0:
        raise ValueError("Could not process data: %s" % err)
    return data, results

class DeltaEnergy(object):
    """
    Holds the forward and reversed energy differences (and their times) read
    from a dE file.
    """

    def __init__(self):
        self.forward = []
        self.reversed = []
        self.time = []

def read_dE_file(dE_fname, time_range=None):
    if not time_range:
        time_range = [
            0.0,
            float("inf"),
        ]
    with open(dE_fname, "r") as fh:
        lines = fh.read().split("\n")
    dE = DeltaEnergy()
    for line in lines:
        line = line.strip()
        if (line != "" and line[0] != "#"):
            time, reversed, forward = [float(e) for e in line.split()]
            if (time >= time_range[0] and time <= time_range[1]):
                dE.time.append(time)
                dE.reversed.append(reversed)
                dE.forward.append(forward)
    return dE

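# Illustrative sketch of the dE file layout this reader expects (hypothetical
# numbers; lines starting with '#' are ignored).  Columns are time, reversed
# energy difference, and forward energy difference:
#
#     # time   dE_reversed   dE_forward
#     100.0    -0.123         0.456
#     101.0    -0.234         0.567
#
#     dE = read_dE_file("gibbs.0.dE", time_range=[100.0, 200.0])
#     dE.time, dE.reversed, dE.forward   # parallel lists of floats
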
def calc_work_prob_distr(energy, energy_range=None):
    """
    Histogram `energy` into a normalized probability distribution.  Returns
    the bin positions and the corresponding probability densities.
    """
    if (energy_range is None):
        e_min = min(energy)
        e_max = max(energy)
        e_span = e_max - e_min
        pad = e_span * 0.05
        energy_range = [
            e_min - pad,
            e_max + pad,
            (e_span + 2 * pad) * 0.009,
        ]
    num_bin = 1
    if energy_range[2] != 0:
        num_bin = int(
            (energy_range[1] - energy_range[0]) / energy_range[2]) + 1
    else:
        # Special case if all energies are the same:
        # use one bin and set the bin width to a nonzero value.
        energy_range[2] = 1
    bin = [0] * num_bin
    num_dat = 0
    for e in energy:
        index = int((e - energy_range[0]) / energy_range[2])
        if (index >= 0 and index < num_bin):
            bin[index] += 1
            num_dat += 1
    x = [0] * num_bin
    for i in range(num_bin):
        bin[i] /= num_dat * energy_range[2]
        x[i] = i * energy_range[2] + energy_range[0]
    return x, bin

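# Illustrative sketch (hypothetical data, not part of the original module):
# the result is a normalized density, so for samples that all fall inside the
# binned range, sum(prob) * bin_width is ~1.
#
#     x, prob = calc_work_prob_distr([0.1, 0.2, 0.2, 0.3])
#     # x[i] is the left edge of bin i; prob[i] is the probability density
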
def calc_forward_reversed_work_overlap(dE0, dE1):
    """
    Calculate the overlap between the forward work distribution of `dE0` and
    the (negated) reversed work distribution of `dE1`.  Returns the common
    bin positions and the two probability distributions.
    """
    dE1.reversed = [e * -1 for e in dE1.reversed]
    e_for_avg = numpy.mean(dE0.forward)
    e_rev_avg = numpy.mean(dE1.reversed)
    e_for_std = numpy.std(dE0.forward)
    e_rev_std = numpy.std(dE1.reversed)
    e_min = min(e_rev_avg - 3 * e_rev_std, e_for_avg - 3 * e_for_std)
    e_max = max(e_rev_avg + 3 * e_rev_std, e_for_avg + 3 * e_for_std)
    e_span = e_max - e_min
    pad = e_span * 0.05
    energy_range = [
        e_min - pad,
        e_max + pad,
        (e_span + 2 * pad) * 0.009,
    ]
    x, prob0 = calc_work_prob_distr(dE0.forward, energy_range)
    x, prob1 = calc_work_prob_distr(dE1.reversed, energy_range)

    def refine_energy_range(x, prob, bin_size):
        x_min = x[0]
        x_max = x[-1]
        inte = 0.0
        for i, e in enumerate(prob):
            inte += e * bin_size
            if (inte > 0.002):
                if (i > 0):
                    x_min = x[i - 1]
                break
        inte = 0.0
        prob.reverse()
        for i, e in enumerate(prob):
            inte += e * bin_size
            if (inte > 0.002):
                if (i > 0):
                    x_max = x[-i - 1]
                break
        return x_min, x_max

    x0_min, x0_max = refine_energy_range(x, prob0, energy_range[2])
    x1_min, x1_max = refine_energy_range(x, prob1, energy_range[2])
    e_min = min(x0_min, x1_min)
    e_max = max(x0_max, x1_max)
    e_span = e_max - e_min
    pad = e_span * 0.05
    energy_range = [
        e_min - pad,
        e_max + pad,
        (e_span + 2 * pad) * 0.009,
    ]
    x, prob0 = calc_work_prob_distr(dE0.forward, energy_range)
    x, prob1 = calc_work_prob_distr(dE1.reversed, energy_range)
    return x, prob0, prob1

def calc_lambda_window_overlap(dE_fname0, dE_fname1, time_range):
    """
    """
    dE0 = read_dE_file(dE_fname0, time_range)
    dE1 = read_dE_file(dE_fname1, time_range)
    if dE0.forward and dE1.forward and dE0.reversed and dE1.reversed:
        return calc_forward_reversed_work_overlap(dE0, dE1)
    raise RuntimeError(
        "Simulation is too short, not enough data for analysis. "
        "The simulation should be at least %.2f ps long" % time_range[0])

def plot_lambda_window_overlap(dE_fname0,
                               dE_fname1,
                               out_fname=None,
                               legend=None,
                               time_range=None,
                               filename=None,
                               reporter=None):
    """
    """
    if not time_range:
        time_range = [
            0.0,
            float("inf"),
        ]
    x, prob0, prob1 = calc_lambda_window_overlap(dE_fname0, dE_fname1,
                                                 time_range)
    if (out_fname):
        with open(out_fname, "w") as fh:
            for i in range(len(x)):
                print(x[i], prob0[i], prob1[i], file=fh)
    if (reporter):
        return reporter.plot(x,
                             prob0,
                             prob1,
                             x_label="energy (kcal/mol)",
                             legend=legend,
                             filename=filename)

def calc_lambda_sim_matrix(num_lambda, *gibbs_dname, **kw):
    """
    """
    import schrodinger.application.desmond.gchart as gchart

    num_dname = len(gibbs_dname)
    traj_length = kw["traj_length"] if ("traj_length" in kw) else 2000.0
    temperature = kw["temperature"] if ("temperature" in kw) else 300.0
    mat = []  # mat[i][j], i is lambda number, j is simulation number.
    for i in range(num_lambda - 1):
        sim = []
        for g0 in gibbs_dname:
            dE_fname0 = "gibbs.%d.dE" % i
            dE_fname1 = "gibbs.%d.dE" % (i + 1)
            util.remove_file("gibbs.0.dE")
            util.symlink(os.path.join(g0, dE_fname0), "gibbs.0.dE")
            for g1 in gibbs_dname:
                util.remove_file("gibbs.1.dE")
                util.symlink(os.path.join(g1, dE_fname1), "gibbs.1.dE")
                data = calc_free_energy_rtime_function(
                    ".", 2, temperature=temperature)
                with open("%s_%d-%s_%d-rtime" % (
                        g0,
                        i,
                        g1,
                        i + 1,
                ), "w") as fh:
                    for d in data:
                        print(d[0], d[1].val, d[1].unc, file=fh)
                    url = gchart.get_xy_url([d[0] for d in data],
                                            [d[1].val for d in data],
                                            err_y=[[d[1].unc for d in data]],
                                            x_label="time (ps)")
                    print("# " + url, file=fh)
                    print("# <img src=\"%s\" />" %
                          (url.replace("&", "&amp;"),),
                          file=fh)
                closest_time = float("inf")
                val_at_closest_time = float("inf")
                for d in data:
                    if (abs(d[0] - traj_length) < closest_time):
                        val_at_closest_time = d[1]
                sim.append(val_at_closest_time)
        mat.append(sim)

    # new_mat[i][j]: i = 0 is the name of the combination, i >= 1 is the
    # lambda index, j is the result, which is a 2-tuple.  The first element
    # of the 2-tuple is the free energy, the second is the probability.
    new_mat = [[]]
    for g0 in gibbs_dname:
        for g1 in gibbs_dname:
            new_mat[0].append((
                g0,
                g1,
            ))
    for i, m in enumerate(mat):
        new_mat.append([])
        prob = []
        for e in m:
            prob.append(math.exp(-e.val / temperature / 1.9872E-3))
        j = 0
        prob_sum = []
        for p in prob:
            if (j % num_dname == 0):
                prob_sum.append(0)
            prob_sum[-1] += p
            j += 1
        for j, p in enumerate(prob):
            prob[j] = old_div(p, prob_sum[old_div(j, num_dname)])
        for e, p in zip(m, prob):
            new_mat[-1].append((
                e,
                p,
            ))
    mat_fname = kw.get("mat_fname")
    if mat_fname:
        with open(mat_fname, "w") as fh:
            print("# lambda,", end=' ', file=fh)
            for m in new_mat[0]:
                print(("%s_%s," % (m[0], m[1])), end=' ', file=fh)
            print(file=fh)
            for i, m in enumerate(new_mat[1:]):
                print(i, end=' ', file=fh)
                for e in m:
                    print(("%s (%g)" % (str(e[0]), e[1])), end=' ', file=fh)
                print(file=fh)
    return new_mat

def _get_pathway_helper(pathway, mat):
    """
    :param pathway: a list of tuples. Each tuple consists of three elements:
        structure name, free energy, and probability.
    """
    last_step_name = pathway[-1][0]
    next_step_option = mat[0]
    all_pathway = []
    next_step_mat = [next_step_option] + mat[2:]
    is_end = mat[2:] == []
    for o, m in zip(next_step_option, mat[1]):
        if (last_step_name == o[0]):
            new_pathway = copy.copy(pathway)
            new_pathway.append((
                o[1],
                m[0],
                m[1],
            ))
            if (is_end):
                all_pathway.append(new_pathway)
            else:
                all_pathway += _get_pathway_helper(new_pathway, next_step_mat)
    return all_pathway

def get_pathway(mat):
    """
    """
    next_step_option = mat[0]
    all_pathway = []
    next_step_mat = [next_step_option] + mat[1:]
    this_step = set()
    for e in next_step_option:
        this_step.add(e[0])
    for e in this_step:
        all_pathway += _get_pathway_helper([
            (
                e,
                measurement.Measurement(0.0, 0.0),
                1.0,
            ),
        ], next_step_mat)
    for p in all_pathway:
        free_energy = 0.0
        prob = 1.0
        for step in p:
            free_energy += step[1]
            prob *= step[2]
    return all_pathway

class FreeEnergyContrib(object):
    """
    Free-energy contributions from the Coulomb, van-der-Waals, and bonded
    terms.
    """

    def __init__(self, coulomb=None, vdw=None, bonded=None):
        self.fec_coulomb = coulomb
        self.fec_vdw = vdw
        self.fec_bonded = bonded

def calc_contrib(fname, cfg_fname):
    """
    Decompose the FEP free energy into Coulomb and van-der-Waals
    contributions, based on the lambda schedules in the config file
    `cfg_fname` and the per-window dG values in `fname`.
    """
    contrib = FreeEnergyContrib()
    with open(cfg_fname) as fh:
        sea_map = config.sea.Map(fh.read())
    if ("gibbs" not in sea_map.force.term.list.val and
            "gibbs" not in sea_map.mdsim.plugin.list.val):
        raise TypeError(
            "Attempted to calculate free-energy components from non-FEP "
            "simulation.")
    if ("gibbs" in sea_map.force.term.list.val):
        gibbs = sea_map.force.term.gibbs
    elif ("gibbs" in sea_map.mdsim.plugin.list.val):
        gibbs = sea_map.mdsim.plugin.gibbs
    if (gibbs.type.val == "alchemical"):
        return contrib

    # Checks whether the vdW schedule overlaps with the Coulomb schedule.
    vdw = gibbs.weights.vdw
    coul = gibbs.weights.es
    prev = None
    for v, c in zip(vdw, coul):
        if (v.val < 1.0 and c.val > 0):
            raise ValueError(
                "Cannot decompose the free energy due to schedule overlap.")
        if (c.val > 0 and prev and prev.val < 1.0):
            raise ValueError(
                "Cannot decompose the free energy due to schedule overlap.")
        prev = v

    # Reads the output file to get dG.
    dG = []
    with open(fname, "r") as fh:
        for line in fh:
            line = line.strip()
            if (line[0] != "#"):
                a = line.split()
                dG.append(measurement.Measurement(a[0], a[1]))

    # Gets the starts and ends of the vdW and Coulomb schedules.
    i_vdw_s = None
    i_vdw_e = None
    i_coul_s = None
    i_coul_e = None
    for i, v in enumerate(vdw):
        if (v.val == 0.0):
            i_vdw_s = i
        if (v.val == 1.0):
            i_vdw_e = i
            break
    for i, c in enumerate(coul):
        if (c.val == 0.0):
            i_coul_s = i
        if (c.val == 1.0):
            i_coul_e = i
            break
    dG_vdw = measurement.Measurement(0.0, 0.0)
    dG_coul = measurement.Measurement(0.0, 0.0)
    if (i_vdw_s is not None and i_vdw_e is not None):
        for i in range(i_vdw_s, i_vdw_e):
            dG_vdw += dG[i]
    if (i_coul_s is not None and i_coul_e is not None):
        for i in range(i_coul_s, i_coul_e):
            dG_coul += dG[i]
    contrib.fec_vdw = dG_vdw
    contrib.fec_coulomb = dG_coul
    contrib.fec_bonded = measurement.Measurement(0.0, 0.0)
    return contrib

def correct_restr(egout0, egout1, fname_out):
    """
    Calculate the free-energy correction due to position restraints from the
    `posre` terms in the two energy-group output files, write a summary to
    `fname_out`, and return the correction as a `measurement.Measurement`
    (or None if either file has no data).
    """
    ene0 = get_energy_table(egout0, ["posre"])[1]["posre"]
    ene1 = get_energy_table(egout1, ["posre"])[1]["posre"]
    len0 = len(ene0)
    len1 = len(ene1)
    if (len0 == 0 or len1 == 0):
        return None
    mean0, error0, stddev0 = get_mean(ene0)
    mean1, error1, stddev1 = get_mean(ene1)
    a0 = measurement.Measurement(mean0, error0)
    a1 = measurement.Measurement(mean1, error1)
    with open(fname_out, "w") as fh:
        print("Restraint energy for lambda=0 (mean, error): %f, %f" % (
            mean0,
            error0,
        ),
              file=fh)
        print("Restraint energy for lambda=1 (mean, error): %f, %f" % (
            mean1,
            error1,
        ),
              file=fh)
        print("Correction due to restraints: %s" % str(a0 - a1), file=fh)
    return a0 - a1

def long_range_dispersion_energy(r_cut, c6, rho):
    """
    :param r_cut: cutoff radius (Angstrom).
    :param c6: average dispersion coefficient (kcal/mol * Angstrom**6).
    :param rho: number density (1 / Angstrom**3)
    """
    return -4.0 / 3.0 * 3.1415926 * rho * c6 / (r_cut * r_cut * r_cut)

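# Worked example (hypothetical numbers): the analytic tail correction is
# E = -(4/3) * pi * rho * c6 / r_cut**3, so for r_cut = 9 A,
# c6 = 600 kcal/mol*A**6 and rho = 0.1 /A**3:
#
#     long_range_dispersion_energy(9.0, 600.0, 0.1)
#     # ~= -(4/3) * 3.14159 * 0.1 * 600 / 729 ~= -0.345 kcal/mol
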
def get_field_from_log(field, fname):
    """
    Return the value of `field` (as a float) from the log file `fname`, or
    None if it is not found.
    """
    with open(fname, "r") as fh:
        content = fh.read()
    PATTERN = re.compile(field + " *= *([.0-9]+)")
    m = PATTERN.search(content)
    if m:
        value = m.group(1)
        return float(value)

def get_number_density_from_cms(model):
    """
    Returns a tuple of elements as follows:

    1. the number density in the unit of 1 / Angstrom**3
    2. number of atoms in the system
    3. volume of the system
    """
    num_atom = model.fsys_ct.atom_total
    volume = cms.get_boxvolume(model.box)
    return (
        old_div(num_atom, volume),
        num_atom,
        volume,
    )

def get_average_box_volume(fname):
    """
    'fname' must be a `*_simbox.dat` file.
    """
    with open(fname, "r") as fh:
        lines = fh.readlines()
    volume_sum = 0.0
    num_data = 0
    for line in lines:
        line = line.strip()
        if ("" != line and '#' != line[0]):
            line = line.replace("Chemical time:", " ")
            line = line.replace("ps, Box vectors:", " ")
            box = [float(e) for e in line.split()]
            volume_sum += cms.get_boxvolume(box[1:])
            num_data += 1
    return old_div(volume_sum, num_data)

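# Illustrative sketch of the *_simbox.dat line layout this reader assumes
# (hypothetical numbers): after the "Chemical time:" and "ps, Box vectors:"
# labels are stripped, field 0 is the time and the remaining fields are the
# box vectors passed to cms.get_boxvolume().
#
#     Chemical time: 100.0 ps, Box vectors: 30.0 0.0 0.0 0.0 30.0 0.0 0.0 0.0 30.0
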
def calc_long_range_dispersion_energy(model,
                                      atom_list,
                                      log_fname=None,
                                      simbox_fname=None,
                                      cfg_fname=None,
                                      r_cut=-1,
                                      average_coefficient=-1):
    """
    """
    if (log_fname and os.path.isfile(log_fname)):
        if r_cut == -1:
            r_cut = get_field_from_log("r_cut", log_fname)
        if average_coefficient == -1:
            average_coefficient = get_field_from_log("average_dispersion",
                                                     log_fname)
    rho, num_atom, volume = get_number_density_from_cms(model)
    if (average_coefficient < 0):
        average_coefficient = cms.calc_average_vdw_coeff(model.comp_ct)
    if (r_cut is None or r_cut < 0):
        if (cfg_fname and os.path.isfile(cfg_fname)):
            with open(cfg_fname) as fh:
                sea_map = config.sea.Map(fh.read())
            r_cut = sea_map.cutoff_radius.val
        else:
            raise ValueError(
                "Lack of information to determine the cutoff radius.")
    if (simbox_fname and os.path.isfile(simbox_fname)):
        average_volume = get_average_box_volume(simbox_fname)
        rho = old_div(num_atom, average_volume)
    vdw = model.get_vdw()
    energy = 0.0
    for atom in atom_list:
        i_atom = int(atom)
        atom_c6 = vdw[i_atom].c6()
        # print("atom %d: type %s, sigma %f, epsilon %f, c6 %f" %
        #       (i_atom, vdw[i_atom].atom_type[0], vdw[i_atom].c[0],
        #        vdw[i_atom].c[1], atom_c6,))
        mixed_c6 = math.sqrt(atom_c6 * average_coefficient)
        energy += long_range_dispersion_energy(r_cut, mixed_c6, rho)
    return energy, r_cut, average_coefficient, rho

def calc_free_energy_correction_due_to_restraint(r, fc, temperature) -> float:
    """
    :param r: Cross-link restraint to calculate the free energy correction for
    :param fc: Three force constants for the stretch, the angle, and the
        torsion restraints, respectively.
    """
    from schrodinger.application.desmond.packages.restraint import \
        CrossLinkRestraint
    assert (isinstance(r, CrossLinkRestraint) and 3 == len(fc))
    stretch = Restraint([r.A, r.a], r.Aa[0], fc[0])
    angle0 = Restraint([r.B, r.A, r.a], r.BAa[0], fc[1])
    angle1 = Restraint([r.A, r.a, r.b], r.Aab[0], fc[1])
    torsion0 = Restraint([r.B, r.A, r.a, r.b], r.BAab[0], fc[2])
    torsion1 = Restraint([r.A, r.a, r.b, r.c], r.Aabc[0], fc[2])
    torsion2 = Restraint([r.C, r.B, r.A, r.a], r.CBAa[0], fc[2])
    return calc_free_energy_for_abfe_cross_link_xu(
        [stretch, angle0, angle1, torsion0, torsion1, torsion2], temperature)

def calc_free_energy(dir, last_time: float, n_win: int, temperature: float,
                     bennett_options: Dict, random_seed: int) -> Dict:
    """
    Return the forward-time, reversed-time, and sliding-window free energy
    curves.
    """
    output = {}
    func_dict = {
        'forward_time': calc_free_energy_time_function,
        'reversed_time': calc_free_energy_rtime_function,
        'sliding_time': calc_free_energy_stime_function
    }
    for attr_str, energy_func in func_dict.items():
        attr = bennett_options[attr_str]
        # copy data from attr since it (attr) might have more variables.
        kwargs = dict(temperature=temperature,
                      begin_time=attr['begin'],
                      end_time=attr['end'],
                      delta_time=attr['dt'],
                      random_seed=random_seed)
        if attr_str == 'sliding_time':
            kwargs['window'] = attr['window']
        data, results = energy_func(dir, last_time, n_win, **kwargs)
        output[attr_str] = (data, results)
    return output

def plot_convergence(data,
                     dG_fname,
                     dF_fname_pattern,
                     x_label,
                     dF_color,
                     dG_color="black",
                     reporter=None) -> Dict[str, Union[List[str], str]]:
    """
    Process the `data` and write png files.

    Return a dictionary with the format
    {'url': url, 'dF': df, 'dG': dG_fname}
    """
    url = []
    df = []
    if len(data) > 2:
        x, y, err = [], [], []
        with open(dG_fname, "w") as fh:
            for d in data:
                print(d[0], '%.4f' % d[1].val, '%.4f' % d[1].unc, file=fh)
                x.append(d[0])
                y.append(d[1].val)
                err.append(d[1].unc)
        if x[-1] >= 10000:
            x = [e * 0.001 for e in x]
            x_label_ = x_label + " (ns)"
        else:
            x_label_ = x_label + " (ps)"
        if reporter:
            url.append(
                reporter.plot(x,
                              y,
                              err_y=[err],
                              x_label=x_label_,
                              y_label="dG",
                              color=[dG_color],
                              filename=dG_fname + ".png"))
        for i in range(len(data[0][2])):
            fname = dF_fname_pattern % (i, i + 1)
            df.append(fname)
            x, y, err = [], [], []
            with open(fname, "w") as fh:
                for d in data:
                    print(d[0], '%.4f' % d[2][i].val, '%.4f' % d[2][i].unc,
                          file=fh)
                    x.append(d[0])
                    y.append(d[2][i].val)
                    err.append(d[2][i].unc)
            if x[-1] >= 10000:
                x = [e * 0.001 for e in x]
                x_label_ = x_label + " (ns)"
            else:
                x_label_ = x_label + " (ps)"
            if reporter:
                url.append(
                    reporter.plot(x,
                                  y,
                                  err_y=[err],
                                  x_label=x_label_,
                                  y_label="dF",
                                  legend=["%d_%d" % (i, i + 1)],
                                  color=[dF_color],
                                  filename=fname + ".png"))
    return {'url': url, 'dF': df, 'dG': dG_fname}