# Source code for schrodinger.test.custom_assertions

"""
Contains custom assertions for use with unittest and Schrodinger data
structures.

copyright (c) Schrodinger, LLC. All rights reserved.
"""
import argparse
import numbers
import shlex
import unittest

import _pytest.assertion.util
from pytest import approx

from schrodinger.infra import mm
from schrodinger.structutils import smiles
from schrodinger.structutils.rmsd import ConformerRmsd

# Marker global looked up by unittest when it formats failure tracebacks:
# frames from modules that define ``__unittest`` are left out of the report.
__unittest = True
"""Keeps stack trace from including this module."""

# Default absolute tolerance for float comparisons in this module; it is the
# default ``tolerance`` of ``_CmpFloat`` and ``assert_properties_match``.
_DEFAULT_TOLERANCE = 0.005


def assertSameNumberOfAtoms(st1, st2):
    """
    Verify that both structures report the same atom count.

    This is a thin helper whose main purpose is to produce a concise
    failure message.

    :param st1: First structure
    :type st1: `schrodinger.structure.Structure`

    :param st2: Second structure
    :type st2: `schrodinger.structure.Structure`

    :rtype: None

    :raise AssertionError: if the ``atom_total`` values of the two
        structures differ.
    """
    atom_counts = (st1.atom_total, st2.atom_total)
    first_count, second_count = atom_counts
    if first_count != second_count:
        raise AssertionError(
            "Number of atoms mismatch: %d != %d" % atom_counts)
def assertSameStructure(st1, st2, smiles_generator=None):
    """
    Verify that two structures have the same connectivity.

    The atom counts are compared first; then a SMILES string is generated
    for each structure and the two strings are compared.

    :param st1: First structure
    :type st1: `schrodinger.structure.Structure`

    :param st2: Second structure
    :type st2: `schrodinger.structure.Structure`

    :param smiles_generator: Optional SMILES generator configured with the
        desired options. When omitted, a generator using
        ``STEREO_FROM_ANNOTATION_AND_GEOM`` is created.
    :type smiles_generator: `schrodinger.structutils.smiles.SmilesGenerator`

    :raise AssertionError: number of atoms or SMILES mismatch.
    """
    assertSameNumberOfAtoms(st1, st2)

    generator = smiles_generator or smiles.SmilesGenerator(
        stereo=smiles.STEREO_FROM_ANNOTATION_AND_GEOM)

    # The tuple is walked in order, so st1 is processed before st2.
    first_smiles, second_smiles = [
        generator.getSmiles(st) for st in (st1, st2)
    ]
    if first_smiles != second_smiles:
        raise AssertionError("SMILES mismatch: {} != {}".format(
            first_smiles, second_smiles))
def assertConformersAlmostEqual(st1, st2, max_rmsd=0.1):
    """
    Check that two conformers (same atoms in the same order) are
    approximately equal by determining the RMSD. If they are not, or if the
    number of atoms differs or the SMILES patterns don't match, an
    AssertionError is raised.

    Ignores rotation and translation.

    :param st1: First structure
    :type st1: `schrodinger.structure.Structure`

    :param st2: Second structure
    :type st2: `schrodinger.structure.Structure`

    :param max_rmsd: Maximum RMSD for two structures to be considered
        "almost equal."
    :type max_rmsd: float

    :rtype: None

    :raise AssertionError: large RMSD or number of atom/connectivity
        mismatch.
    """
    # Connectivity must agree before an RMSD between the two is meaningful.
    assertSameStructure(st1, st2)
    rmsd = ConformerRmsd(st1, st2).calculate()
    if rmsd > max_rmsd:
        # "%.3f" is required here: "%.3d" is an integer conversion and
        # would render e.g. 0.532 as "000".
        raise AssertionError("Structures have an RMSD of %.3f." % rmsd)
def assertEqualShFiles(file1, file2):
    """
    Check that two files contain the same Schrodinger command, where each
    command starts with "${SCHRODINGER}/run". The files to be compared
    should be the ones written by Appframework2 when a write() function is
    called.

    Only the text from the first space onwards is compared, so the leading
    launcher token (whose path separator is OS specific) is ignored.

    :param file1: File containing the first command
    :type file1: str

    :param file2: File containing the second command
    :type file2: str

    :rtype: None

    :raise AssertionError: commands different
    :raise ValueError: if the stripped contents of either file contain no
        space character (raised by ``str.index``).
    """
    with open(file1) as ref_cmd_fd:
        ref_cmd = ref_cmd_fd.read().strip()
    with open(file2) as cmd_fd:
        cmd = cmd_fd.read().strip()

    # Ignore "${SCHRODINGER}/run", as the path separator will be OS specific:
    # drop all characters up to the first space and compare the rest.
    if ref_cmd[ref_cmd.index(' '):] != cmd[cmd.index(' '):]:
        raise AssertionError(
            f"Commands did not match:\nFrom {file1}:\n{ref_cmd}\n"
            f"From {file2}:\n{cmd}")
def assertEqualCommandFiles(file1,
                            file2,
                            arg_parser,
                            skip_count=0,
                            ignore_options=None):
    """
    Check that the commands stored in two files are the same.

    Each file is expected to hold exactly one line, which is the command to
    compare. The contents of each file are tokenized with `shlex.split` and
    the resulting argument lists are handed to `assertEqualCommandArgs`.
    Because an ArgumentParser is used for the comparison, ideally only the
    arguments (not the command name) should be compared.

    :param file1: Path to command file1
    :type file1: str

    :param file2: Path to command file2
    :type file2: str

    :param arg_parser: Function that takes a list of command arguments and
        returns an argparse.Namespace, or a tuple (argparse.Namespace, list)
        if the arg_parser is an ArgumentParser.parse_known_args()
    :type arg_parser: func(list)

    :param skip_count: Number of items to skip at the beginning of the
        command. Default is 0. Can be used to skip the program name in the
        command.
    :type skip_count: int

    :param ignore_options: List of options to be ignored while comparing the
        commands. Default is not to ignore anything.
    :type ignore_options: list(options)
    """

    def read_tokens(path):
        # Shell-style split of the complete file contents.
        with open(path) as handle:
            return shlex.split(handle.read())

    first_tokens = read_tokens(file1)
    second_tokens = read_tokens(file2)
    assertEqualCommandArgs(first_tokens[skip_count:],
                           second_tokens[skip_count:], arg_parser,
                           ignore_options)
def assertEqualCommandArgs(args1, args2, arg_parser, ignore_options=None):
    """
    Check if two commands passed as lists of arguments are the same.

    Both argument lists are parsed with ``arg_parser`` and the attribute
    dicts of the resulting namespaces are compared with `assertEqualDicts`.

    :param args1: Argument list1
    :type args1: list

    :param args2: Argument list2
    :type args2: list

    :param arg_parser: Function that takes a list of command arguments and
        returns an argparse.Namespace, or a tuple (argparse.Namespace, list)
        if the arg_parser is an ArgumentParser.parse_known_args()
    :type arg_parser: func(list)

    :param ignore_options: List of options to be ignored while comparing the
        commands. Default is not to ignore anything.
    :type ignore_options: list(options)

    :raise ValueError: if ``arg_parser`` returns neither a Namespace nor a
        non-empty tuple whose first item is a Namespace.
    :raise AssertionError: if the parsed commands differ.
    """

    def get_namespace(parsed):
        # isinstance (rather than an exact type check) also accepts
        # Namespace / tuple subclasses returned by custom parsers.
        if isinstance(parsed, argparse.Namespace):
            return parsed
        # The emptiness check keeps an empty tuple from raising IndexError;
        # it is reported as an invalid return type instead.
        if (isinstance(parsed, tuple) and parsed and
                isinstance(parsed[0], argparse.Namespace)):
            return parsed[0]
        raise ValueError(
            f'Invalid return type by arg_parser:{type(parsed)}')

    namespace1 = get_namespace(arg_parser(args1))
    namespace2 = get_namespace(arg_parser(args2))
    assertEqualDicts(vars(namespace1), vars(namespace2), ignore_options)
def assertEqualDicts(dict1, dict2, ignore_keys=None, tolerance=None):
    """
    Compare two python dicts.

    :param dict1: First dict
    :type dict1: dict

    :param dict2: Second dict
    :type dict2: dict

    :param ignore_keys: List of keys to be ignored while comparing the
        dicts. Default is not to ignore anything.
    :type ignore_keys: list(keys)

    :param tolerance: Absolute tolerance for comparing fractional numeric
        values. Pass None to always compare with `!=`.
    :type tolerance: float or None

    :raise AssertionError: if the dict lengths differ (checked only when
        nothing is ignored), if a non-ignored key is missing from either
        dict, or if the values stored under a key differ. In the last case
        the offending key is the exception's argument.
    """
    if ignore_keys is None:
        ignore_keys = []

    # Check for the length of the dicts only if there is nothing to ignore
    if not ignore_keys and len(dict1) != len(dict2):
        raise AssertionError(f'dict lengths are not equal:\n{dict1}\n{dict2}\n')

    def compare_dicts(left, right):
        for key in left:
            if key in ignore_keys:
                continue
            if key not in right:
                raise AssertionError(f'key {key} not present in both dicts')
            val1 = left[key]
            val2 = right[key]
            if tolerance and (_is_fractional(val1) or _is_fractional(val2)):
                values_match = val1 == approx(val2, abs=tolerance)
            else:
                values_match = val1 == val2
            # Raise explicitly instead of using ``assert`` so the check is
            # not stripped when Python runs with -O.
            if not values_match:
                raise AssertionError(key)

    # Compare all keys of dict1 with dict2 and vice-versa to check if they
    # exactly have same set of keys
    compare_dicts(dict1, dict2)
    compare_dicts(dict2, dict1)
def _is_fractional(i): """ Return True if i is a real or rational number (but not an integer). """ return isinstance(i, numbers.Real) and not isinstance(i, numbers.Integral)
class StructureAssertionsTestCase(unittest.TestCase):
    """
    "Convenience" class to allow structural assertions to be called similarly
    to built in assertions.

    Each method simply forwards to the module-level function of the same
    name.
    """

    def assertSameNumberOfAtoms(self, st1, st2):
        """
        Check that two structures have the same number of atoms.

        :param st1: First structure
        :type st1: `schrodinger.structure.Structure`

        :param st2: Second structure
        :type st2: `schrodinger.structure.Structure`

        :raise AssertionError: number of atom mismatch.
        """
        assertSameNumberOfAtoms(st1, st2)

    def assertSameStructure(self, st1, st2, smiles_generator=None):
        """
        Check that two structures have the same connectivity.

        :param st1: First structure
        :type st1: `schrodinger.structure.Structure`

        :param st2: Second structure
        :type st2: `schrodinger.structure.Structure`

        :param smiles_generator: Optional SMILES generator forwarded to the
            module-level `assertSameStructure`; None selects its default.
        :type smiles_generator:
            `schrodinger.structutils.smiles.SmilesGenerator`

        :raise AssertionError: number of atom/connectivity mismatch.
        """
        assertSameStructure(st1, st2, smiles_generator)

    def assertConformersAlmostEqual(self, st1, st2, max_rmsd=0.1):
        """
        Check that two conformers are approximately equal by RMSD.

        :param st1: First structure
        :type st1: `schrodinger.structure.Structure`

        :param st2: Second structure
        :type st2: `schrodinger.structure.Structure`

        :param max_rmsd: Maximum RMSD for two structures to be considered
            "almost equal."
        :type max_rmsd: float

        :raise AssertionError: large RMSD or number of atom/connectivity
            mismatch.
        """
        # Forward the caller's threshold; a hard-coded 0.1 here would
        # silently ignore the ``max_rmsd`` argument.
        assertConformersAlmostEqual(st1, st2, max_rmsd=max_rmsd)
class _CmpFloat:
    """
    Store a float with a custom tolerance for equality comparisons.

    Two values compare equal when the absolute difference between them is
    strictly smaller than this instance's tolerance. The other operand may
    be another `_CmpFloat` or a plain number. Because ``__eq__`` is defined
    without ``__hash__``, instances are unhashable.
    """

    def __init__(self, value, tolerance=_DEFAULT_TOLERANCE):
        """
        :param value: The wrapped numeric value
        :type value: float

        :param tolerance: Absolute tolerance used by equality comparisons
        :type tolerance: float
        """
        self._value = value
        self._tolerance = tolerance

    def __eq__(self, that):
        # Unwrap another _CmpFloat; anything else is compared directly.
        other_value = getattr(that, '_value', that)
        try:
            return abs(self._value - other_value) < self._tolerance
        except TypeError:
            # Not a number (e.g. str or None): defer to Python's default
            # comparison handling, which yields "not equal", rather than
            # raising from within ``==``.
            return NotImplemented

    def __ne__(self, that):
        # ``__ne__`` is the special method Python calls for ``!=``; a method
        # named ``__neq__`` is never invoked by the interpreter.
        result = self.__eq__(that)
        if result is NotImplemented:
            return result
        return not result

    def __str__(self):
        return str(self._value)

    def __repr__(self):
        # Mirrors the wrapped value so dict diffs read like plain floats.
        return repr(self._value)

    @classmethod
    def asCmp(cls, value, tolerance=_DEFAULT_TOLERANCE):
        """
        Give "value" a tolerance for equality comparison if it is a float.

        :param value: Any value
        :param tolerance: Absolute tolerance to attach to float values
        :type tolerance: float

        :return: A `_CmpFloat` wrapping ``value`` when it is a float,
            otherwise ``value`` itself, unchanged.
        """
        if isinstance(value, float):
            return cls(value, tolerance)
        return value
def assert_properties_match(st1,
                            st2,
                            ignore=None,
                            tolerance=_DEFAULT_TOLERANCE,
                            msg=None):
    """
    Assert that the Structure level properties of st1 and st2 match.

    Float property values are compared with an absolute tolerance. The
    stereo status of each structure is included in the comparison under the
    key ``mm.MMCT_STEREO_STATUS_PROP``. Properties listed in ``ignore`` are
    skipped. The failure message uses pytest-like formatting for the
    dictionary diff.

    Output like::

        custom_assertions.assert_properties_match(st1, st2)
        E   AssertionError: Omitting 2 identical items, use -vv to show
        E   Right contains 1 more item:
        E   {'i_user_test': 22}

    or::

        custom_assertions.assert_properties_match(st1, st2)
        E   AssertionError: Omitting 2 identical items, use -vv to show
        E   Differing items:
        E   {'i_user_test': 21} != {'i_user_test': 22}

    :param st1: First structure
    :param st2: Second structure
    :param ignore: Optional container of property names to leave out of the
        comparison
    :param tolerance: Absolute tolerance for float property values
    :type tolerance: float
    :param msg: Optional text placed on the first line of the failure
        message
    :type msg: str or None

    :raise AssertionError: if the compared properties differ.
    """

    def comparable_properties(st):
        # Wrap floats so that dict equality honours the tolerance.
        props = {}
        for name, value in st.property.items():
            props[name] = _CmpFloat.asCmp(value, tolerance)
        props[mm.MMCT_STEREO_STATUS_PROP] = mm.mmct_ct_get_stereo_status(st)
        if not ignore:
            return props
        return {
            name: value for name, value in props.items() if name not in ignore
        }

    left = comparable_properties(st1)
    right = comparable_properties(st2)
    if left == right:
        return

    # NOTE(review): _compare_eq_dict is a private pytest helper; its
    # signature may vary between pytest versions -- confirm against the
    # pinned pytest.
    explanation = _pytest.assertion.util._compare_eq_dict(left, right, True)
    if msg:
        explanation.insert(0, msg)
    raise AssertionError('\n'.join(explanation))