Source code for schrodinger.utils.sea.evalor

"""
Module for parameter validation. See `schrodinger.utils.sea` for more details.
Copyright Schrodinger, LLC. All rights reserved.
"""
import inspect
import os
import re
from copy import deepcopy

from .common import boolean
from .common import debug_print
from .common import is_equal
from .sea import Atom
from .sea import List
from .sea import Map


class Evalor:
    """
    Evaluator used while checking the validity of parameters.

    An instance holds the map being validated, the accumulated error text,
    and the names of parameters for which no validation criteria were found.
    Calling the instance evaluates a validation expression against the map.
    """
    __slots__ = [
        "_map",
        "_err_break",
        "_err",
        "_unchecked_map",
    ]

    def __init__(self, map, err_break="\n\n"):
        """
        :param map: 'map' contains all parameters to be checked.
        :param err_break: The separator appended after every recorded error
            message.
        """
        self._map = map
        self._err_break = err_break
        self._err = ""
        self._unchecked_map = []

    def __call__(self, arg):
        """
        Evaluates 'arg' against the stored map and returns the result.

        :param arg: The validation criteria.
        """
        return _eval(self._map, arg)

    @property
    def err(self):
        """
        The error messages recorded so far, as a single string.
        """
        return self._err

    def is_ok(self):
        """
        Returns True if there is neither a recorded error nor an unchecked
        map, or False otherwise.
        """
        return not (self._err or self._unchecked_map)

    def record_error(self, mapname=None, err=""):
        """
        Records an error.

        :param mapname: The name of the checked parameter. Its first character
            is dropped before the name is written into the error text. If
            None, no name is written.
        :param err: The error message.
        """
        debug_print("ERROR\n%s" % err)
        if mapname is not None:
            self._err = self._err + mapname[1:] + ": "
        self._err = self._err + err + self._err_break

    @property
    def unchecked_map(self):
        """
        Returns a string that tells which parameters have not been checked.
        Each name has its first character dropped and is followed by a single
        space.
        """
        return "".join(name[1:] + " " for name in self._unchecked_map)

    def copy_from(self, ev):
        """
        Takes over the map, the error text and the unchecked-map list of 'ev'.
        The objects are shared with 'ev', not duplicated, and the error
        separator of this object is left untouched.

        :param ev: A 'Evalor' object.
        """
        for slot in ("_map", "_err", "_unchecked_map"):
            setattr(self, slot, getattr(ev, slot))
def check_map(map, valid, ev, tag=set()):  # noqa: M511
    """
    Checks the validity of a map and reports the outcome through
    'debug_print'.

    :param map: The map to be checked. Nothing is checked unless
        'map.has_tag(tag)' is true; the check itself runs on 'map.sval'.
    :param valid: The validation criteria for 'map'.
    :param ev: The evaluator, where error messages and unchecked parameters
        are collected.
    :param tag: The tags handed to 'map.has_tag' and to the checker.
    :return: The accumulated error string if any error was recorded, or None
        otherwise (including when nothing was checked).
    """
    if not map.has_tag(tag):
        debug_print("(none is tagged with: %s)" % (", ".join(tag)))
        return None

    root = map.sval
    _check_map(root, valid, ev, "", tag)

    debug_print("\nUnchecked maps:")
    unchecked = "(none)" if ev._unchecked_map == [] else ev.unchecked_map
    debug_print(unchecked)

    debug_print("\nError summary:")
    if ev._err == "":
        debug_print("(none)")
        return None
    debug_print(ev._err)
    return ev._err
def __op_mul(map, arg):
    """
    Evaluates the "multiplication" expression and returns the product of
    arg[0], arg[1], arg[2], ...

    :param arg: The 'arg' should be a 'sea.List' object that contains two or
        more elements.
    :param map: The original map that the elements in the 'arg' refer.
    :return: The product as a float (the running product starts at 1.0).
    """
    prod = 1.0
    for e in arg:
        prod *= _eval(map, e)
    return prod


def __op_eq(map, arg):
    """
    Evaluates the "equal" expression and returns True if arg[0] and arg[1]
    are equal or False otherwise. If either value is a float, the comparison
    is delegated to 'is_equal'.

    :param arg: The 'arg' should be a 'sea.List' object that contains two
        elements. Elements beyond the first two will be ignored.
    :param map: The original map that the elements in the 'arg' refer.
    """
    a = _eval(map, arg[0])
    b = _eval(map, arg[1])
    if isinstance(a, float) or isinstance(b, float):
        return is_equal(a, b)
    return a == b


def __op_lt(map, arg):
    """
    Evaluates the "less than" expression and returns True if arg[0] is less
    than arg[1] or False otherwise.

    :param arg: The 'arg' should be a 'sea.List' object that contains two
        elements. Elements beyond the first two will be ignored.
    :param map: The original map that the elements in the 'arg' refer.
    """
    return _eval(map, arg[0]) < _eval(map, arg[1])


def __op_le(map, arg):
    """
    Evaluates the "less or equal" expression and returns True if arg[0] is
    less than or equal to arg[1] or False otherwise.

    :param arg: The 'arg' should be a 'sea.List' object that contains two
        elements. Elements beyond the first two will be ignored.
    :param map: The original map that the elements in the 'arg' refer.
    """
    return _eval(map, arg[0]) <= _eval(map, arg[1])


def __op_gt(map, arg):
    """
    Evaluates the "greater than" expression and returns True if arg[0] is
    greater than arg[1] or False otherwise.

    :param arg: The 'arg' should be a 'sea.List' object that contains two
        elements. Elements beyond the first two will be ignored.
    :param map: The original map that the elements in the 'arg' refer.
    """
    return _eval(map, arg[0]) > _eval(map, arg[1])


def __op_ge(map, arg):
    """
    Evaluates the "greater or equal" expression and returns True if arg[0] is
    greater than or equal to arg[1] or False otherwise.

    :param arg: The 'arg' should be a 'sea.List' object that contains two
        elements. Elements beyond the first two will be ignored.
    :param map: The original map that the elements in the 'arg' refer.
    """
    return _eval(map, arg[0]) >= _eval(map, arg[1])


def __op_and(map, arg):
    """
    Evaluates the "logic and" expression. The result is truthy only if both
    arg[0] and arg[1] are truthy. As with Python's 'and', the value returned
    is one of the evaluated operands, and arg[1] is not evaluated when arg[0]
    is falsy.

    :param arg: The 'arg' should be a 'sea.List' object that contains two
        elements. Elements beyond the first two will be ignored.
    :param map: The original map that the elements in the 'arg' refer.
    """
    return _eval(map, arg[0]) and _eval(map, arg[1])


def __op_or(map, arg):
    """
    Evaluates the "logic or" expression. The result is truthy if either arg[0]
    or arg[1] is truthy. As with Python's 'or', the value returned is one of
    the evaluated operands, and arg[1] is not evaluated when arg[0] is truthy.

    :param arg: The 'arg' should be a 'sea.List' object that contains two
        elements. Elements beyond the first two will be ignored.
    :param map: The original map that the elements in the 'arg' refer.
    """
    return _eval(map, arg[0]) or _eval(map, arg[1])


def __op_not(map, arg):
    """
    Evaluates the "logic not" expression and returns True if arg[0] is false
    or False if arg[0] is true.

    :param arg: The 'arg' should be a 'sea.List' object that contains exactly
        1 element.
    :param map: The original map that the elements in the 'arg' refer.
    :raise ValueError: If 'arg' does not contain exactly 1 element.
    """
    if len(arg) != 1:
        raise ValueError(
            "'__op_not' function expects 1 argument, but there are %d" %
            len(arg))
    return not _eval(map, arg[0])


def __op_at(map, arg):
    """
    Evaluates the "at" expression and returns the referenced value: arg[0] is
    evaluated and used as a key into 'map'. If the object found has a 'val'
    attribute, that is returned, otherwise the object itself.

    :param arg: The 'arg' should be a 'sea.List' object that contains exactly
        1 element.
    :param map: The original map that the elements in the 'arg' refer.
    :raise ValueError: If 'arg' does not contain exactly 1 element.
    """
    if len(arg) != 1:
        raise ValueError(
            "'__op_at' function expects 1 argument, but there are %d" %
            len(arg))
    k = map[_eval(map, arg[0])]
    try:
        return k.val
    except AttributeError:
        return k


def __op_minus(map, arg):
    """
    Evaluates the "minus" expression and returns the arithmetic result: the
    negated value for one argument, or the difference arg[0] - arg[1] for two.

    :param arg: The 'arg' should be a 'sea.List' object that contains one or
        two elements.
    :param map: The original map that the elements in the 'arg' refer.
    :raise ValueError: If 'arg' contains neither one nor two elements.
    """
    num_arg = len(arg)
    # An empty argument list is rejected here as well; otherwise it would fall
    # through to the subtraction branch and fail on arg[0] with an unrelated
    # IndexError.
    if num_arg not in (1, 2):
        raise ValueError(
            "'__op_minus' function expects 1 or 2 arguments, but there are %d"
            % num_arg)
    if num_arg == 1:
        return -_eval(map, arg[0])
    return _eval(map, arg[0]) - _eval(map, arg[1])


def __op_cat(map, arg):
    """
    Concatenates the string forms of all evaluated arguments and returns the
    result.

    :param arg: The 'arg' should be a 'sea.List' object that contains at least
        1 element.
    :param map: The original map that the elements in the 'arg' refer.
    :raise ValueError: If 'arg' is empty.
    """
    if len(arg) < 1:
        raise ValueError(
            "'__op_cat' function expects at least 1 argument, but there is none"
        )
    return "".join(str(_eval(map, a)) for a in arg)


def __op_sizeof(map, arg):
    """
    Evaluates the "sizeof" expression and returns the length ('len') of the
    evaluated arg[0].

    :param arg: The 'arg' should be a 'sea.List' object that contains exactly
        1 element.
    :param map: The original map that the elements in the 'arg' refer.
    :raise ValueError: If 'arg' does not contain exactly 1 element.
    """
    if len(arg) != 1:
        raise ValueError(
            "'__op_sizeof' function expects 1 argument, but there are %d" %
            len(arg))
    return len(_eval(map, arg[0]))
def is_powerof2(x):
    """
    Returns True if 'x' is a power of 2, or False otherwise.

    :param x: The integer to test.
    :return: True for 1, 2, 4, 8, ...; False for every other integer,
        including 0 and negative numbers.
    """
    # A power of two has exactly one bit set, so clearing the lowest set bit
    # with x & (x - 1) leaves 0. Zero passes that bit test as well, so it is
    # excluded explicitly.
    return x > 0 and not (x & (x - 1))
def _regex_match(pattern):
    """
    Returns a function that takes a string and returns the result of
    're.match(pattern, s)', i.e. a match object, or None if the beginning of
    the string does not match.

    :param pattern: The regular expression.
    """
    return lambda s: re.match(pattern, s)


def _xchk_power2(map, valid, ev, prefix):
    """
    This is an external checker. It checks whether an integer value is power
    of 2 or not.

    :param map: 'map' contains the value to be checked. Use 'map.val' to get
        the value.
    :param valid: 'valid' contains the validation criteria for the
        to-be-checked value.
    :param ev: The evaluator, where the error messages are collected.
    :param prefix: The prefix of the checked parameter.
    """
    val = map.val
    if not is_powerof2(val):
        debug_print("Error:\nValue %d is not an integer of power of 2" % val)
        ev.record_error(prefix,
                        "Value %d is not an integer of power of 2" % val)
    else:
        debug_print("OK - value is an integer of powere of 2")


def _xchk_file_exists(map, valid, ev, prefix):
    """
    This is an external checker. It checks whether a file (not a dir) exists.
    An empty file name is accepted without touching the file system.

    :param map: 'map' contains the value to be checked. Use 'map.val' to get
        the file name.
    :param valid: 'valid' contains the validation criteria for the
        to-be-checked value.
    :param ev: The evaluator, where the error messages are collected.
    :param prefix: The prefix of the checked parameter.
    """
    val = map.val
    if val != "" and not os.path.isfile(val):
        debug_print("Error:\nFile not found: %s" % val)
        ev.record_error(prefix, "File not found: %s" % val)
    else:
        debug_print("OK - file exists")


def _xchk_dir_exists(map, valid, ev, prefix):
    """
    This is an external checker. It checks whether a dir (not a file) exists.
    An empty directory name is accepted without touching the file system.

    :param map: 'map' contains the value to be checked. Use 'map.val' to get
        the directory name.
    :param valid: 'valid' contains the validation criteria for the
        to-be-checked value.
    :param ev: The evaluator, where the error messages are collected.
    :param prefix: The prefix of the checked parameter.
    """
    val = map.val
    if val != "" and not os.path.isdir(val):
        debug_print("Error:\nDirectory not found: %s" % val)
        ev.record_error(prefix, "Directory not found: %s" % val)
    else:
        debug_print("OK - Directory exists")


def _eval(map, arg):
    """
    Evaluates the expression and returns the results.

    For a 'sea.List' the first element is evaluated first. If it yields a
    string naming a registered operator (after stripping whitespace), the
    operator is applied to the remaining elements. Otherwise a plain Python
    list with every element evaluated is returned.

    For anything else 'arg.val' is returned, except that a string value
    starting with "@" is treated as a reference: the rest of the string is
    used as a key into 'map', and the 'val' of the object found (or the object
    itself if it has no 'val') is returned. The bare strings "-", "@" and ""
    are returned unchanged.

    :param arg: 'arg' can be either a 'sea.List' object or a 'sea.Atom'
        object, representing a prefix expression.
    :param map: The original map that the elements in the 'arg' refer.
    """
    if isinstance(arg, List):
        if len(arg) == 0:
            # Nothing to dispatch on and nothing to evaluate.
            return []
        val0 = _eval(map, arg[0])
        if isinstance(val0, str):
            val0 = val0.strip()
            if val0 in __OP:
                return __OP[val0](map, arg[1:])
        return [_eval(map, e) for e in arg]

    val = arg.val
    if val in ['-', '@', '']:
        return val
    try:
        if val[0] == "@":
            k = map[val[1:]]
            try:
                return k.val
            except AttributeError:
                return k
    except TypeError:
        # 'val' is not subscriptable (e.g. a number), so it cannot be a
        # reference.
        pass
    return val


# Operators recognized as the first element of a prefix expression.
__OP = {
    "*": __op_mul,
    "==": __op_eq,
    "<": __op_lt,
    "<=": __op_le,
    ">": __op_gt,
    ">=": __op_ge,
    "&&": __op_and,
    "||": __op_or,
    "!": __op_not,
    "@": __op_at,
    "-": __op_minus,
    "cat": __op_cat,
    "sizeof": __op_sizeof,
}

# Type names usable in validation criteria. A tuple entry is
# (type, built-in range).
__TYPE = {
    "str": str,
    "str1": (
        str,
        [1, 1000000000],
    ),
    "float": float,
    "float+": (
        float,
        [0, float("inf")],
    ),
    "float-": (
        float,
        [float("-inf"), 0],
    ),
    "float0_1": (
        float,
        [0, 1.0],
    ),
    "int": int,
    "int0": (
        int,
        [0, 1000000000],
    ),
    "int1": (
        int,
        [1, 1000000000],
    ),
    "bool": boolean,
    "bool0": (
        boolean,
        [False],
    ),
    "bool1": (
        boolean,
        [True],
    ),
    "enum": str,
    "list": list,
    "none": None,
    "regex": _regex_match,
}

# For each actual type, the expected types it is also accepted for.
__CONVERTIBLE_TO = {
    int: [float, str],
    float: [str],
}

# Registry of external checkers, keyed by the name used in the validation
# criteria.
__xcheck = {
    "power2": _xchk_power2,
    "file_exists": _xchk_file_exists,
    "dir_exists": _xchk_dir_exists,
}
def reg_xcheck(name, func):
    """
    Registers an external checker under 'name'. A checker already registered
    under the same name is replaced.

    :param name: Name of the checker.
    :param func: Callable object that checks validity of a parameter. For
        interface requirement, see '_xchk_power2', or '_xchk_file_exists', or
        '_xchk_dir_exists' for example.
    """
    __xcheck.update({name: func})
def _match(map, valid, ev, prefix, tag):
    """
    Finds the best match among several alternative validation criteria.

    'map' is checked against every alternative in 'valid', each time with a
    deep copy of 'ev'. The winning copy is then taken over by 'ev'. It is
    chosen by, in order: the fewest unchecked parameters, no error text beyond
    what 'ev' already had, and the fewest "Wrong type:" errors.

    :param map: The parameter to be checked.
    :param valid: An iterable of alternative validation criteria.
    :param ev: The evaluator, where the error messages are collected.
    :param prefix: The prefix of the checked parameter.
    :param tag: The tags, forwarded to the check of every alternative.
    """
    kk = map
    vv = valid
    ev_list = []
    for vv_ in vv:
        # NOTE(review): '_if' is looked up on the whole list 'vv', not on the
        # current alternative 'vv_', and a false value abandons all
        # alternatives. '_check_map' evaluates the '_if' of each alternative
        # itself. Kept as is - confirm whether 'vv_._if' was intended.
        try:
            _if = ev(vv._if)
        except AttributeError:
            pass
        else:
            debug_print("_if: {} = {}".format(
                str(vv._if),
                _if,
            ), False)
            if _if:
                debug_print("True")
            else:
                debug_print("False - Skip checking the whole map.")
                return
        ev_ = deepcopy(ev)
        _check_map(kk, vv_, ev_, prefix, tag)
        ev_list.append(ev_)

    if ev_list != []:
        # Tries to find the best match.
        # 1. Keeps the alternatives with the fewest unchecked parameters.
        candidate = [
            ev_list[0],
        ]
        least = len(candidate[0]._unchecked_map)
        for ev_ in ev_list[1:]:
            num = len(ev_._unchecked_map)
            if num < least:
                candidate = [
                    ev_,
                ]
                least = num
            elif num == least:
                candidate.append(ev_)

        # 2. Prefers the alternatives that added no error at all.
        best_ev = [ev_ for ev_ in candidate if ev_._err == ev._err]
        if best_ev == []:
            best_ev = candidate
        candidate = best_ev

        # 3. Picks the first alternative with the fewest type errors.
        best_ev = candidate[0]
        least = best_ev._err.count("Wrong type:")
        for ev_ in candidate[1:]:
            num = ev_._err.count("Wrong type:")
            if num < least:
                best_ev = ev_
                least = num
        ev.copy_from(best_ev)


def _check_atom(atom, valid, ev, prefix):
    """
    Checks the validity of atom: first its type, then its range, and finally
    the external checkers listed under '_check' in the criteria.

    :param atom: The atom to be checked.
    :param valid: The validation criteria for the atom.
    :param ev: The evaluator, where the error messages are collected.
    :param prefix: The prefix of the checked parameter.
    """
    rr = None  # Range

    # type
    debug_print(prefix + ":")
    debug_print(" checking its type...", False)
    try:
        t = ev(valid.type)
        if t.startswith("regex:"):
            tt = __TYPE["regex"](t[6:])
        else:
            tt = __TYPE[t]
        if isinstance(tt, tuple):
            tt, rr = tt[0], tt[1]
    except AttributeError:
        ev.record_error(
            prefix,
            "Wrong type: expecting a composite parameter, but got an atom")
        return
    except KeyError:
        ev.record_error(
            prefix,
            "Wrong type: %s. 'type' is likely a parameter than a description."
            % t)
        return

    atom_val = atom.val
    if atom_val is None:
        if tt is None:
            debug_print("OK - value None is acceptable")
        else:
            ev.record_error(prefix,
                            "Wrong value: expecting %s, but got None" % str(tt))
        return

    if atom._type == str and inspect.isfunction(tt) and tt != boolean:
        # 'tt' is the matcher built for a "regex:<pattern>" type.
        if tt(atom_val):
            debug_print("OK - {} matches the pattern: {}".format(
                atom_val, t[6:]))
        else:
            ev.record_error(
                prefix,
                "Wrong type: expecting a string matching {}, but got {}".format(
                    t[6:],
                    atom_val,
                ))
            return
    elif (atom._type != tt and (atom._type not in __CONVERTIBLE_TO or
                                tt not in __CONVERTIBLE_TO[atom._type])):
        ev.record_error(
            prefix, "Wrong type: expecting {}, but got {}".format(
                "boolean" if tt == boolean else str(tt),
                str(atom._type),
            ))
        return
    else:
        debug_print("OK - %s" % t)

    # range
    debug_print(" checking its range...", False)
    try:
        # The range built into the type takes precedence over 'valid.range'.
        if rr is None:
            rr = ev(valid.range)
    except AttributeError:
        debug_print("N/A")
    else:
        if t == "enum" or tt == boolean:
            if atom_val not in rr:
                ev.record_error(
                    prefix,
                    "Wrong value: should be one of {}, but got '{}'".format(
                        str(rr),
                        str(atom_val),
                    ))
            else:
                debug_print("OK - '{}' is one of {}".format(
                    str(atom_val),
                    str(rr),
                ))
        elif tt == str:
            # For strings the range limits the length.
            if atom._type != tt:
                atom_val = str(atom_val)
            length = len(atom_val)
            if length > int(rr[1]):
                ev.record_error(prefix,
                                "String is too long (%d char's)" % length)
            elif length < int(rr[0]):
                # Converted like in the comparison above, so that '%d' also
                # works for a bound that is not an int.
                ev.record_error(
                    prefix,
                    "String is too short: it must have at least %d char's" %
                    int(rr[0]))
            else:
                debug_print("OK - string has %d char's" % length)
        else:
            if atom_val > tt(rr[1]) or atom_val < tt(rr[0]):
                ev.record_error(
                    prefix,
                    "Value out of range: expecting within %s, but got '%s'" %
                    (str(rr), str(atom_val)))
            else:
                debug_print("OK - {} is within {}".format(
                    str(atom_val),
                    str(rr),
                ))

    # _check
    try:
        cc = valid._check
    except AttributeError:
        pass
    else:
        debug_print(" external checking...")
        if isinstance(cc, List):
            for e in cc:
                debug_print(" %s: " % e.val, False)
                __xcheck[e.val](atom, valid, ev, prefix)
        elif cc.val != "":
            debug_print(" %s: " % cc.val, False)
            __xcheck[cc.val](atom, valid, ev, prefix)


def _check_list(map, valid, ev, prefix, tag):
    """
    Checks the validity of list: its type, its size, each of its elements,
    and finally the external checkers listed under '_check' in the criteria.

    A positive 'size' in the criteria demands exactly that many elements, a
    negative one demands at least '-size' elements.

    :param map: The list to be checked.
    :param valid: The validation criteria for the list.
    :param ev: The evaluator, where the error messages are collected.
    :param prefix: The prefix of the checked parameter.
    :param tag: The tags, forwarded to the check of every element.
    """
    kk = map
    vv = valid

    # type
    debug_print(prefix + ":")
    debug_print(" checking its type...", False)
    try:
        t = ev(vv.type)
        tt = __TYPE[t]
    except AttributeError:
        ev.record_error(
            prefix,
            "Wrong type: expecting a composite parameter, but got a list")
        return
    except KeyError:
        # Unknown type name: reported the same way '_check_atom' does.
        ev.record_error(
            prefix,
            "Wrong type: %s. 'type' is likely a parameter than a description."
            % t)
        return
    if tt != list:
        ev.record_error(
            prefix, "Wrong type: expecting %s, but got <type 'list'>" % str(tt))
        return
    debug_print("OK - list")

    # size
    try:
        debug_print(" checking its size...", False)
        size = ev(vv.size)
        ll = len(kk)
        if size > 0 and ll != size:
            ev.record_error(
                prefix,
                "Wrong list length: expecting %d elements, but got %d" % (
                    size,
                    ll,
                ))
        elif size < 0 and ll < -size:
            ev.record_error(
                prefix,
                "Wrong list length: expecting at least %d elements, but got %d"
                % (
                    -size,
                    ll,
                ))
        else:
            debug_print("OK - %d" % ll)
    except AttributeError:
        pass

    # elem
    debug_print(" checking each element in list...", False)
    try:
        if isinstance(vv.elem, List):
            # One criterion per position; the last criterion also covers all
            # elements beyond the end of the criteria list.
            lv, lk = len(vv.elem), len(kk)
            for i, (k, v) in enumerate(zip(kk, vv.elem)):
                _check_map(k, v, ev, ("%s[%d]" % (prefix, i)), tag)
            if lv < lk:
                v = vv.elem[-1]
                for i in range(lv, lk):
                    _check_map(kk[i], v, ev, ("%s[%d]" % (prefix, i)), tag)
        else:
            # A single criterion applies to every element.
            for i, elem in enumerate(kk):
                _check_map(elem, vv.elem, ev, ("%s[%d]" % (prefix, i)), tag)
    except AttributeError:
        debug_print("OK - No requirement for elements")

    # _check
    try:
        cc = vv._check
        debug_print(" external checking for the whole list...")
        if isinstance(cc, List):
            for e in cc:
                debug_print(" %s: " % e.val, False)
                __xcheck[e.val](map, valid, ev, prefix)
        elif cc.val != "":
            debug_print(" %s: " % cc.val, False)
            __xcheck[cc.val](map, valid, ev, prefix)
    except AttributeError:
        pass


def _check_map(map, valid, ev, prefix="", tag=set()):  # noqa: M511
    """
    Checks the validity of a map.

    The criteria may carry these special entries: '_if' (skip everything if
    it evaluates false), '_skip' (keys not to be checked, or "all"),
    '_mapcheck' (external checkers for the whole map) and '_enforce' (keys
    that must be present). Keys of the map without criteria are recorded as
    unchecked in 'ev'.

    :param map: The parameter to be checked: an atom, a list or a map.
    :param valid: The validation criteria. If it is a 'sea.List', its items
        are treated as alternatives and the best match is used.
    :param ev: The evaluator, where the error messages are collected.
    :param prefix: The prefix of the checked parameter.
    :param tag: The tags handed to 'map.key_value' to select the key-value
        pairs to check.
    :raise ValueError: If '_skip' is neither a list nor the string "all", or
        if '_enforce' is not a list.
    """
    # _if
    try:
        _if = ev(valid._if)
    except AttributeError:
        pass
    else:
        debug_print("_if: {} = {}".format(
            str(valid._if),
            _if,
        ), False)
        if _if:
            debug_print("True")
        else:
            debug_print("False - Skip checking the whole map.")
            return

    if isinstance(valid, List):
        return _match(map, valid, ev, prefix, tag)

    if isinstance(map, Atom):
        _check_atom(map, valid, ev, prefix)
    elif isinstance(map, List):
        _check_list(map, valid, ev, prefix, tag)
    elif isinstance(map, Map):
        # _skip
        try:
            skip = valid._skip.val
        except AttributeError:
            skip = []
        else:
            if not isinstance(skip, list) and skip != "all":
                raise ValueError(
                    "_skip must be either a list of strings or the string \"all\""
                )

        # _mapcheck
        try:
            cc = valid._mapcheck
        except AttributeError:
            pass
        else:
            debug_print(prefix + ":")
            debug_print(" external checking for the whole map...")
            if isinstance(cc, List):
                for e in cc:
                    debug_print(" %s: " % e.val, False)
                    __xcheck[e.val](map, valid, ev, prefix)
            elif cc.val != "":
                debug_print(" %s: " % cc.val, False)
                __xcheck[cc.val](map, valid, ev, prefix)

        # _enforce
        try:
            cc = valid._enforce
        except AttributeError:
            pass
        else:
            if not isinstance(cc, List):
                raise ValueError("_enforce must be a list of strings")
            debug_print(prefix + ":")
            debug_print(" enforcing keys...", False)
            # join() leaves no trailing separator, so the result is used
            # as is.
            missing_key = ", ".join(e for e in cc.val if e not in map)
            if missing_key == "":
                debug_print("OK - All enforced keys present")
            else:
                debug_print("Error\nMissing keys: " + missing_key)
                ev.record_error(prefix, "Missing keys: " + missing_key)

        if "all" != skip:
            # Key-value pairs
            key_value = [
                (k, kk) for k, kk in map.key_value(tag) if (k not in skip)
            ]
            for k, kk in key_value:
                try:
                    vv = valid[k]
                except KeyError:
                    ev._unchecked_map.append(prefix + '.' + k)
                    continue
                _check_map(kk, vv, ev, prefix + '.' + k, tag)