# Source code for schrodinger.application.desmond.enhsamp

import os
import re
import sys
from functools import reduce
from past.utils import old_div

import numpy

from . import antlr3
from .enhanced_sampling.FcnTypes import getFcnSigs
from .enhanced_sampling.mexpLexer import mexpLexer
from .enhanced_sampling.mexpParser import ADDOP
from .enhanced_sampling.mexpParser import BIND
from .enhanced_sampling.mexpParser import BLOCK
from .enhanced_sampling.mexpParser import CUTOFF
from .enhanced_sampling.mexpParser import DECL_META
from .enhanced_sampling.mexpParser import DECL_OUTPUT
from .enhanced_sampling.mexpParser import DIM
from .enhanced_sampling.mexpParser import ELEM
from .enhanced_sampling.mexpParser import FIRST
from .enhanced_sampling.mexpParser import HEADER
from .enhanced_sampling.mexpParser import IF
from .enhanced_sampling.mexpParser import INITKER
from .enhanced_sampling.mexpParser import INTERVAL
from .enhanced_sampling.mexpParser import ITER
from .enhanced_sampling.mexpParser import LIT
from .enhanced_sampling.mexpParser import NAME
from .enhanced_sampling.mexpParser import SERIES
from .enhanced_sampling.mexpParser import STATIC
from .enhanced_sampling.mexpParser import STRING
from .enhanced_sampling.mexpParser import SUBTROP
from .enhanced_sampling.mexpParser import VAR
from .enhanced_sampling.mexpParser import mexpParser

# FIXME add line numbers to errors


def showtype(i):
    """Return a human-readable description of an m-expression type.

    Types produced by ``get_type`` are either the string ``'string'`` or an
    integer array length.

    >>> showtype('string')
    'string'
    >>> showtype(3)
    'length-3 array'
    """
    # Test the value itself: ``type(i)`` is a type object and is never a str,
    # so checking it would report strings as "length-string array".
    if isinstance(i, str):
        return 'string'
    return 'length-' + str(i) + ' array'
class Node(object):
    """Base class of the m-expression syntax tree.

    Each node keeps the Env it was created with and a list of child nodes.
    Subclasses provide ``get_type`` and, where needed, override
    ``resolve_atomsel`` and ``constant_fold``.
    """

    def __init__(self, env, children):
        self.children = children
        self.env = env

    def resolve_atomsel(self, aslobj, gids):
        """Resolve atom selections in all children and return this node.

        :param aslobj: object with an ``atomsel(asl)`` method returning gids.
        :param gids: set that collects every selected gid.
        """
        self.children = [c.resolve_atomsel(aslobj, gids) for c in self.children]
        return self

    def constant_fold(self):
        """Constant-fold all children and return this node.

        Constant folding may assume that the program is well-typed, i.e. that
        ``get_type`` has already succeeded on the tree.
        """
        self.children = [c.constant_fold() for c in self.children]
        return self

    def get_type(self, env=None):
        """Validate the type correctness of this subtree and return its type.

        A type is either an int (the length of a numeric array) or the string
        ``'string'``.  get_type also ensures that all variables can be
        uniquely resolved in the given scope.  The base class has no typing
        rule.

        ``env`` defaults to None only to stay compatible with the previous
        zero-argument signature; subclasses are always called with an Env.

        :raises ValueError: always, for the base class.
        """
        raise ValueError('Internal error. Function not defined')
class Lit(Node):
    """Numeric literal holding a sequence of values.

    Its type is the number of values (a length-N array).  Within this module
    the values are always built with ``float(...)``.
    """

    def __init__(self, env, value):
        # should validate input
        Node.__init__(self, env, [])
        self.value = value

    def get_type(self, env):
        return len(self.value)

    def __str__(self):
        values = self.value
        if len(values) != 1:
            return '[literal %s]' % ' '.join(map(repr, values))
        return repr(values[0])
class String(Node):
    """String literal node; its type is ``'string'``.

    Strings appear as the ASL argument of ``atomsel`` and as the static
    variable name in ``load``/``store``.
    """

    def __init__(self, env, value):
        # should validate input
        Node.__init__(self, env, [])
        self.value = value

    def get_type(self, env):
        return 'string'

    def __str__(self):
        # assumes the value contains no characters that need escaping
        quote = '"'
        return quote + self.value + quote
class Var(Node):
    """Reference to a lexically bound variable, printed as ``$name``."""

    def __init__(self, env, name):
        Node.__init__(self, env, [])
        self.name = name

    def get_type(self, env):
        """Return the type bound to this variable in the current scopes.

        :raises KeyError: if the variable is not bound in any enclosing scope.
        """
        try:
            return env.binds[self.name]
        except KeyError:
            # use a descriptive message; the local name no longer shadows
            # the builtin ``str``
            msg = 'Variable %s unknown' % self.name
            raise KeyError(msg)

    def __str__(self):
        return '$' + self.name
class Bind(Node):
    """Binding of ``name`` to the value expression stored as the only child.

    Printed as ``[$name value]``; its type is the type of the value.
    """

    def __init__(self, env, name, value):
        Node.__init__(self, env, [value])
        self.name = name

    def get_type(self, env):
        bound_value = self.children[0]
        return bound_value.get_type(env)

    def __str__(self):
        return '[${} {}]'.format(str(self.name), str(self.children[0]))
class FcnCall(Node):
    """Call of the function or operator ``name`` on the argument nodes.

    Arithmetic operators, ``array``, ``elem``, ``load``/``store``, ``meta``
    and ``atomsel`` are all represented as FcnCall nodes.
    """

    def __init__(self, env, name, children):
        Node.__init__(self, env, children)
        self.name = name

    def __str__(self):
        arglist = ' '.join([str(n) for n in self.children])
        return '[%s %s]' % (self.name, arglist)

    def resolve_atomsel(self, aslobj, gids):
        """Replace ``atomsel("<asl>")`` with an ``array`` of gid literals.

        Every selected gid is added to the set *gids*.  Any other call just
        recurses into its arguments.

        :raises TypeError: if atomsel does not have exactly one argument.
        :raises ValueError: if that argument is not a string literal.
        """
        if self.name != 'atomsel':
            return Node.resolve_atomsel(self, aslobj, gids)
        if len(self.children) != 1:
            raise TypeError('atomsel takes one argument')
        s = self.children[0]
        if type(s) is not String:
            raise ValueError('Argument to atomsel must have type string')
        env = Env()
        sel = aslobj.atomsel(s.value)
        gids.update(sel)
        return FcnCall(env, 'array', [Lit(env, [float(i)]) for i in sel])

    @staticmethod
    def _bin_thread(f, arg1, arg2):
        """Apply binary *f* elementwise, broadcasting a length-1 operand."""
        if len(arg1) == 1:
            x = arg1[0]
            return [f(x, y) for y in arg2]
        if len(arg2) == 1:
            y = arg2[0]
            return [f(x, y) for x in arg1]
        if len(arg1) == len(arg2):
            return [f(x, y) for x, y in zip(arg1, arg2)]
        raise RuntimeError('Internal error. Invalid type on binary operation')

    def constant_fold(self):
        """Fold calls whose arguments are all literals.

        ``array`` of literals becomes a single literal.  Binary ``+ * - /``
        on literals are evaluated.  This simple folding is primarily intended
        for cases where the parser turns "-5.0" into ``[* -1 5.0]``.
        Anything else returns this node with folded children.
        """
        self.children = [c.constant_fold() for c in self.children]
        if not all(isinstance(c, Lit) for c in self.children):
            return self
        vals = [c.value for c in self.children]
        if self.name == 'array':
            elems = []
            for v in vals:
                elems.extend(v)
            return Lit(Env(), elems)
        binary_ops = {
            '+': lambda x, y: x + y,
            '*': lambda x, y: x * y,
            '-': lambda x, y: x - y,
            # old_div (past.utils) is true division for float operands
            '/': lambda x, y: old_div(x, y),
        }
        op = binary_ops.get(self.name)
        if op is None or len(vals) != 2:
            return self
        return Lit(Env(), self._bin_thread(op, vals[0], vals[1]))

    def _load_type(self, env):
        """Type of ``load("name")``: the declared type of static *name*."""
        if len(self.children) != 1:
            raise TypeError('Wrong number of arguments to load')
        arg = self.children[0]
        t = arg.get_type(env)
        if not isinstance(t, str):
            raise TypeError('Must pass string to load, not ' + showtype(t))
        if arg.value not in env.statics:
            raise ValueError('unknown variable %s' % arg.value)
        return env.statics[arg.value]

    def _store_type(self, env):
        """Type of ``store("name", value)``; value must match the static."""
        if len(self.children) != 2:
            raise TypeError('Wrong number of arguments to store')
        arg0 = self.children[0]
        t = arg0.get_type(env)
        if not isinstance(t, str):
            raise TypeError('Must pass string as argument 1 of store, not '
                            + showtype(t))
        if arg0.value not in env.statics:
            raise ValueError('variable %s is unknown in store' % arg0.value)
        tdes = env.statics[arg0.value]
        t = self.children[1].get_type(env)
        if tdes != t:
            raise TypeError('attempt to store %s in %s, but %s has type %s' %
                            (showtype(t), arg0.value, arg0.value,
                             showtype(tdes)))
        return env.statics[arg0.value]

    def _check_meta(self, env, child_types):
        """Check a ``meta`` call whose accumulator id is a scalar literal."""
        mid = int(round(self.children[0].value[0]))
        if mid < 0 or mid >= len(env.metas):
            raise ValueError(
                'metadynamics accumulator id outside range of accumulators')
        d = env.metas[mid].dim
        if d != child_types[2]:
            raise TypeError(('metadynamics accumulator %i has dimension %i' +
                             ' but was passed a %s collective variable') %
                            (mid, d, showtype(child_types[2])))

    def get_type(self, env):
        """Type-check this call and return its result type.

        ``load``/``store`` are checked against the declared statics.  All
        other calls are checked against the signature table ``env.sigs``.
        A ``meta`` call with a scalar literal accumulator id is additionally
        checked against the declared accumulators.

        :raises TypeError: on argument count/type mismatches.
        :raises ValueError: on unknown statics or accumulator ids.
        :raises KeyError: if the function name is not in ``env.sigs``
            (assuming ``env.sigs`` is dict-like).
        """
        if self.name == 'load':
            return self._load_type(env)
        if self.name == 'store':
            return self._store_type(env)
        child_types = [c.get_type(env) for c in self.children]
        try:
            sig = env.sigs[self.name]
        except KeyError:
            raise KeyError('Unknown function %s' % self.name)
        t = sig.check(child_types)
        if (self.name == 'meta' and type(self.children[0]) is Lit and
                len(self.children[0].value) == 1):
            self._check_meta(env, child_types)
        return t
class Iter(Node):
    """Series iterator ``name`` with lower bound lb and upper bound ub.

    Both bounds must be length-1 arrays; the iterator itself has type 1.
    """

    def __init__(self, env, name, lb, ub):
        Node.__init__(self, env, [lb, ub])
        self.name = name

    def get_type(self, env):
        lower_type = self.children[0].get_type(env)
        upper_type = self.children[1].get_type(env)
        errmsg = '%s bound of iterator %s must be a length-1 array but is a %s'
        for label, bound_type in (('Lower', lower_type), ('Upper', upper_type)):
            if bound_type != 1:
                raise TypeError(errmsg % (label, self.name, showtype(bound_type)))
        return 1

    def __str__(self):
        lower, upper = self.children[0], self.children[1]
        return '[$%s %s %s]' % (self.name, lower, upper)
class Let(Node):
    """Sequential block, printed as ``[let [binds...] value]``.

    Every statement in *binds* that is not already a Bind is wrapped in a
    Bind to a fresh gensym name.  Each statement is thus bound, and
    type-checked, in the block's scope before the final *value*.
    """

    def __init__(self, env, binds, value):
        # Wrap first and build children from the wrapped list, so that
        # self.binds and self.children[:-1] are the same Bind objects.
        # resolve_atomsel/constant_fold rewrite self.children, while
        # get_type/__str__ read self.binds; sharing the objects keeps them
        # consistent.  A new list is built so the caller's list is untouched.
        wrapped = [b if isinstance(b, Bind) else Bind(env, env.gensym(), b)
                   for b in binds]
        Node.__init__(self, env, wrapped + [value])
        self.binds = wrapped

    def get_type(self, env):
        """Type-check the binds in a new scope and return the value's type."""
        env.binds.enter_scope()
        try:
            for b in self.binds:
                env.binds.add_var(b.name, b.get_type(env))
            return self.children[-1].get_type(env)
        finally:
            # leave the scope even when type checking fails
            env.binds.leave_scope()

    def __str__(self):
        bindlist = ' '.join(map(str, self.binds))
        return '[let [%s] %s]' % (bindlist, str(self.children[-1]))
class Series(Node):
    """Series expression, printed as ``[series [iters...] value]``.

    The iterators are bound in a new scope while *value* is type-checked.
    """

    def __init__(self, env, iters, value):
        Node.__init__(self, env, iters + [value])
        self.iters = iters

    def get_type(self, env):
        scopes = env.binds
        scopes.enter_scope()
        for iterator in self.iters:
            scopes.add_var(iterator.name, iterator.get_type(env))
        body_type = self.children[-1].get_type(env)
        scopes.leave_scope()
        return body_type

    def __str__(self):
        iter_text = ' '.join(str(it) for it in self.iters)
        return '[series [%s] %s]' % (iter_text, str(self.children[-1]))
class If(Node):
    """Conditional, printed as ``[if cond then else]``.

    The condition must have type 1 and both branches must have the same
    type, which is the type of the whole expression.
    """

    def __init__(self, env, cond, then_case, else_case):
        Node.__init__(self, env, [cond, then_case, else_case])

    def get_type(self, env):
        cond_type, then_type, else_type = [
            self.children[k].get_type(env) for k in (0, 1, 2)
        ]
        if cond_type != 1:
            raise TypeError(
                'Condition of if expression must have type 1 but has type %s'
                % showtype(cond_type))
        if then_type == else_type:
            return then_type
        raise TypeError(
            'The branches of the if expression must have the same types but '
            'the types are %s and %s' % (showtype(then_type),
                                         showtype(else_type)))

    def __str__(self):
        c = self.children
        return '[if %s %s %s]' % (c[0], c[1], c[2])
class Meta(object):
    """Declaration of one metadynamics accumulator from the program header.

    Stores the dimension, cutoff, first/interval values, output name and the
    initial-kernels name.  ``str()`` gives the backend configuration record.
    """

    def __init__(self, dim, cutoff, first, interval, output, initial):
        (self.dim, self.cutoff, self.first, self.interval, self.output,
         self.initial) = (dim, cutoff, first, interval, output, initial)

    def __str__(self):
        template = ('{dimension=%i cutoff=%r first=%r interval=%r name="%s" '
                    'initial_kernels="%s" accumulate_on_the_device="%s"}')
        return template % (self.dim, self.cutoff, self.first, self.interval,
                           self.output, self.initial, 'true')
# binding represents a scoped mapping from variable names to their types
class binding(object):
    """Stack of lexical scopes mapping variable names to their types.

    ``bs[0]`` is the innermost scope; lookups search from the innermost
    scope outward.
    """

    def __init__(self):
        self.bs = [{}]

    def enter_scope(self):
        """Push a new, empty innermost scope."""
        self.bs = [{}] + self.bs

    def leave_scope(self):
        """Drop the innermost scope."""
        self.bs = self.bs[1:]

    def add_var(self, name, tp):
        """Bind *name* to type *tp* in the innermost scope.

        :raises ValueError: if *name* is already bound in that scope.
        """
        if name in self.bs[0]:
            raise ValueError('%s declared twice in the same scope' % name)
        self.bs[0][name] = tp

    def __getitem__(self, it):
        """Return the type of *it* from the innermost scope that binds it.

        :raises KeyError: carrying the missing name, if no scope binds it.
        """
        for scope in self.bs:
            if it in scope:
                return scope[it]
        # not declared in any scope; include the name for diagnostics
        raise KeyError(it)

    def __str__(self):
        return ''.join(str(scope) for scope in self.bs)
class Env(object):
    """Compilation environment for an m-expression program.

    Holds the declared static storage (name -> type), the metadynamics
    accumulators (list of Meta), the optional output declaration, the
    lexical variable scopes, and the function signature table.
    """

    def __init__(self):
        self.statics = {}
        self.metas = []
        self.gensym_cnt = 0
        self.output = None
        self.binds = binding()
        self.sigs = getFcnSigs()

    def gensym(self):
        """Return a fresh name ``'gensym<N>'``, unique within this Env."""
        i = self.gensym_cnt
        self.gensym_cnt = i + 1
        return 'gensym' + str(i)

    def add_static(self, nm, type):
        """Declare static storage *nm* with array length *type*.

        (The parameter name ``type`` shadows the builtin but is kept for
        keyword-argument compatibility.)

        :raises ValueError: if *nm* is already declared or *type* < 0.
        """
        if nm in self.statics:
            raise ValueError('Variable (%s) bound twice in same scope' % nm)
        if type < 0:
            raise ValueError("Type of static variable %s cannot be negative" % nm)
        self.statics[nm] = type

    def add_output(self, nm, first, interval):
        """Record the output declaration ``(nm, first, interval)``.

        :raises ValueError: if an output was already declared.  (Previously a
            bare string was raised, which is itself a TypeError in Python 3.)
        """
        if self.output is not None:
            raise ValueError("Output cannot be declared twice")
        self.output = (nm, first, interval)

    def __str__(self):
        d = self.statics
        if self.output is not None:
            name, first, interval = self.output
        else:
            name, first, interval = ("", 0.0, 0.0)
        parts = [
            'metadynamics_accumulators=[%s] ' % ' '.join(map(str, self.metas)),
            "storage={%s} " % ' '.join(['%s=%i' % (k, d[k]) for k in sorted(d)]),
            'name="%s" first=%s interval=%s' % (name, repr(first), repr(interval)),
        ]
        return ''.join(parts)
def _header_attrs(decl, errmsg):
    """Collect a header declaration's attributes into {token type: text}.

    Each attribute's child tokens are concatenated; per the original code
    this is what lets a negated number arrive as one string.

    :raises ValueError: with *errmsg* if an attribute is repeated.
    """
    attrs = {}
    for attr in decl.children:
        key = attr.getToken().getType()
        if key in attrs:
            raise ValueError(errmsg)
        attrs[key] = ''.join([tok.getToken().getText() for tok in attr.children])
    return attrs


def headerToEnv(header):
    """Build an Env from the HEADER subtree of a parsed program.

    Handles STATIC declarations (name and integer length), DECL_META
    (a metadynamics accumulator; CUTOFF defaults to 0.0) and DECL_OUTPUT.
    Quoted NAME/INITKER values have their surrounding quotes stripped.

    :raises ValueError: for a non-header node, an unknown declaration kind,
        or a repeated attribute.  (Previously bare strings were raised for
        repeated attributes, which is a TypeError in Python 3.)
    """
    if header.getToken().getType() != HEADER:
        raise ValueError('Internal error. Expected a header node')
    env = Env()
    for decl in header.children:
        t = decl.getToken().getType()
        if t == STATIC:
            tokens = [x.getToken().getText() for x in decl.children]
            env.add_static(tokens[0], int(tokens[1]))
        elif t == DECL_META:
            d = _header_attrs(
                decl, "Invalid meta declaration -- redeclared parameter")
            d.setdefault(CUTOFF, "0.0")
            env.metas.append(
                Meta(int(d[DIM]),
                     float(d[CUTOFF]),
                     float(d[FIRST]),
                     float(d[INTERVAL]),
                     d[NAME][1:-1],
                     d[INITKER][1:-1]))
        elif t == DECL_OUTPUT:
            d = _header_attrs(decl, "Invalid output declaration")
            env.add_output(d[NAME][1:-1], float(d[FIRST]), float(d[INTERVAL]))
        else:
            raise ValueError('Internal error. Unknown header declaration')
    return env
def bodyToNode(tree, env):
    """Convert an ANTLR parse (sub)tree of the program body into Nodes.

    Children are converted first (depth-first).  BLOCK becomes a Let (or the
    single statement itself); IF, BIND, ITER and SERIES map to their node
    classes; ELEM chains become nested ``elem`` calls; a VAR naming a
    declared static becomes ``load("name")``; prefix +/- are simplified.
    Any other token becomes a FcnCall named after the token text.

    Exits the process with status 1 if the tree has no token (parse failure).

    :raises ValueError: for an empty block.  (Previously a bare string was
        raised, which is a TypeError in Python 3.)
    """
    tok = tree.getToken()
    if tok is None:
        sys.stderr.write('failed to parse m-expression\n')
        sys.exit(1)
    t = tok.getType()
    cs = [bodyToNode(c, env) for c in tree.children]
    if t == BLOCK:
        # FIXME must handle this better for gensym
        if not cs:
            raise ValueError("Error: empty block")
        if len(cs) == 1:
            return cs[0]
        return Let(env, cs[:-1], cs[-1])
    elif t == IF:
        return If(env, cs[0], cs[1], cs[2])
    elif t == BIND:
        return Bind(env, cs[0].name, cs[1])
    elif t == ITER:
        return Iter(env, cs[0].name, cs[1], cs[2])
    elif t == SERIES:
        return Series(env, cs[:-1], cs[-1])
    elif t == ELEM:
        # a[i][j] -> [elem [elem a i] j]
        return reduce(lambda l, r: FcnCall(env, 'elem', [l, r]), cs[1:], cs[0])
    elif t == VAR:
        nm = tree.children[0].getText()
        if nm in env.statics:
            return FcnCall(env, 'load', [String(env, nm)])
        return Var(env, nm)
    elif t == LIT:
        return Lit(env, [float(tok.getText())])
    elif t == STRING:
        val = tok.getText()  # includes quote marks
        return String(env, val[1:-1])
    elif t == ADDOP and len(cs) == 1:  # prefix +
        return cs[0]
    elif t == SUBTROP and len(cs) == 1:  # prefix -, encoded as -1 * x
        return FcnCall(env, '*', [Lit(env, [-1.0]), cs[0]])
    # Due to some ambiguities in the processing, 'store' must be special
    # cased: its first argument was converted to load("name"); unwrap it.
    if (tok.getText() == 'store' and len(cs) >= 1 and
            isinstance(cs[0], FcnCall) and cs[0].name == 'load'):
        return FcnCall(env, 'store', [cs[0].children[0]] + cs[1:])
    return FcnCall(env, tok.getText(), cs)
def parse_indices(indices_string):
    """
    Parse an ASL index list and return the unique indices in ascending order.

    Only individual indices and ranges ('a-b', inclusive) are supported.  As
    in ASL, both ' ' and ',' are separators, so '7 3 4, 2- ,, 7' is
    equivalent to '3, 7 4, 2-7'.  A '-' with no index before it is ignored
    together with the index that follows it.

    Note: evaluate_asl and parse_indices do not agree on '-7, -3- ,,,4'. The
    former gives [5, 6, 7], which is not consistent with the definition of
    ASL.

    :raises RuntimeError: on a non-integer token, or a '-' that is not
        followed by an integer.
    """

    def is_integer(t):
        try:
            int(t)
        except ValueError:
            return False
        return True

    error_msg = "Failed to parse indices: %s" % indices_string

    # ',' is treated like whitespace; each word is further split on '-' so
    # that '2-7' becomes ['2', '-', '7'].
    tokens = []
    for tail in indices_string.replace(',', ' ').split():
        while tail:
            head, dash, tail = tail.partition('-')
            head = head.strip()
            if head:
                if not is_integer(head):
                    raise RuntimeError(error_msg)
                tokens.append(head)
            if dash:
                tokens.append(dash)
            if not dash or not tail:
                break

    indices = []
    pending = []  # individual indices seen since the last range
    i = 0
    while i < len(tokens):
        tok = tokens[i]
        if tok == '-':
            i += 1
            if i >= len(tokens) or not is_integer(tokens[i]):
                raise RuntimeError(error_msg)
            if pending:
                # the index just before '-' starts the range
                begin = pending.pop()
                indices.extend(range(int(begin), int(tokens[i]) + 1))
                indices.extend(int(e) for e in pending)
                pending = []
        else:
            pending.append(tok)
        i += 1
    indices.extend(int(e) for e in pending)
    # A set alone does not guarantee ascending order (e.g. {1, 8}); sort it
    # to match ASL's behavior as documented.
    return sorted(set(indices))
class ASLObject:
    """Translate ASL atom-selection expressions into gids.

    With a model, the selection is evaluated by ``model.select_atom`` and
    mapped through ``model.gid``.  Without one, only expressions matching
    ``_index_only_pattern`` are supported (see ``_gids_without_structure``).
    """

    # NOTE(review): '.' is unescaped, so it matches any single character;
    # presumably meant as a literal 'atom.' prefix -- confirm before changing.
    _index_only_pattern = re.compile(r'atom.\s+(.*)')

    def __init__(self, model):
        self._cms = model

    def atomsel(self, asl_expr):
        """Return the gids of the atoms selected by *asl_expr*.

        :raises RuntimeError: if there is no model and *asl_expr* is not an
            index-only expression.
        """
        if not self._cms:
            return self._gids_without_structure(asl_expr)
        selected = numpy.asarray(self._cms.select_atom(str(asl_expr)))
        return self._cms.gid(selected)

    def _gids_without_structure(self, asl_expr):
        # FIXME
        # .cms is needed in order to translate front end config file to
        # backend config file. Unfortunately, to restart the simulation
        # from checkpoint file, one does not have a .cms file.
        # This is a workaround to get gids without .cms file. It will only
        # work if gid = atid - 1 is true for all atoms. In other words, it
        # will fail on restarting FEP with metadynamics or the reaction
        # coordinate happen to involve molecules with virtual site.
        match = self._index_only_pattern.match(str(asl_expr))
        if not match:
            raise RuntimeError(
                "Failed to get gid from asl ('%s') without structure." %
                asl_expr)
        atom_ids = numpy.array(parse_indices(match.group(1)))
        return list(atom_ids - 1)
def procText(text):
    """Lex and parse the m-expression program *text*.

    Returns ``(action, env)``.  *env* is the Env built from the header
    (first child of the parse tree); *action* is the Node tree built from
    the body (second child).
    """
    lexer = mexpLexer(antlr3.StringStream(text))
    parse_tree = mexpParser(antlr3.CommonTokenStream(lexer)).prog().tree
    header_tree, body_tree = parse_tree.children[0], parse_tree.children[1]
    env = headerToEnv(header_tree)
    return (bodyToNode(body_tree, env), env)
def resolve_atomsel(body, model):
    """Resolve every ``atomsel(...)`` call in *body* using *model*.

    Returns ``(newbody, gids)``.  *newbody* is the rewritten tree, and
    *gids* is the set of all selected gids.

    While the ASLObject is constructed, stdout is temporarily redirected to
    stderr.  Historically this was for VMD, which printed its banner to
    stdout.  ASLObject's constructor as defined in this module only stores
    the model, so the redirection is likely vestigial; it is kept for safety.
    """
    sys.stdout.flush()
    sys.stderr.flush()
    saved_stdout = os.dup(sys.stdout.fileno())
    try:
        os.dup2(sys.stderr.fileno(), sys.stdout.fileno())
        aslobj = ASLObject(model)
    finally:
        # Always restore stdout and release the duplicated descriptor.
        # Previously the dup'ed fd leaked on every call, and stdout stayed
        # redirected if an exception was raised.
        sys.stdout.flush()
        sys.stderr.flush()
        os.dup2(saved_stdout, sys.stdout.fileno())
        os.close(saved_stdout)
    gids = set()
    newbody = body.resolve_atomsel(aslobj, gids)
    return (newbody, gids)
def parseStr(system, mexp):
    """
    Compile the m-expression *mexp* for *system* into the partial frontend
    config string for the enhanced_sampling plugin.

    The program is parsed, its atom selections are resolved against *system*,
    it is type-checked and then constant-folded.  The resulting potential
    must be a length-1 array.

    :raises TypeError: if the program's potential is not a length-1 array.
    """
    action, env = procText(mexp)
    resolved, gids = resolve_atomsel(action, system)
    pot_type = resolved.get_type(env)
    # note that constant folding assumes a well-typed program
    folded = resolved.constant_fold()
    gid_text = ' '.join(str(g) for g in sorted(gids))
    is_scalar = type(pot_type) is not str and pot_type == 1
    if not is_scalar:
        raise TypeError(
            'Potential must be a length-1 array, but is currently ' +
            showtype(pot_type))
    return '{type=enhanced_sampling gids=[%s] sexp=%s %s}' % (gid_text, folded, env)
def parse_mexpr(system, mexp):
    """
    Return the partial backend config string that adds the enhanced_sampling
    plugin, as force term ``ES``, for the m-expression *mexp* on *system*.
    """
    es_config = parseStr(system, mexp)
    return 'force.term{list[+]=ES ES=%s}' % es_config