# Source code for schrodinger.application.matsci.mstest

"""
Unit test related functions/classes.

Copyright Schrodinger, LLC. All rights reserved.
"""

import contextlib
import os
import unittest

import numpy
import numpy.testing as npt

from schrodinger.application.matsci.nano import xtal
from schrodinger.infra import mm
from schrodinger.structure import Structure
from schrodinger.structutils import rmsd
from schrodinger import test as sdgr_test
from schrodinger.Qt.QtCore import Qt


def assert_same_struct(struct1, struct2, target_rmsd=0.0, decimal=1, msg=None):
    """
    Assert that two structures are almost the same.

    The check has three parts:
      1. both structures have the same number of atoms;
      2. the vectors returned by ``xtal.get_vectors_from_chorus`` agree to
         3 decimal places;
      3. the RMSD after superimposing all atoms (without symmetry) equals
         ``target_rmsd`` to ``decimal`` places.

    :param `structure.Structure` struct1: First structure
    :param `structure.Structure` struct2: Second structure. Superposition
        may move its coordinates. They are always restored afterwards,
        even if superposition raises.
    :param float target_rmsd: Target RMSD for two structures
    :param int decimal: The number of decimal places used for comparing the
        two RMSD floats. In unittest needs to be set fairly loosely so that
        the tests pass on different OSs
    :param str or None msg: message shown if failure occurs
    :raise AssertionError: when the two struct are not equal within given
        tolerance
    """
    if struct1.atom_total != struct2.atom_total:
        raise AssertionError('{} != {}. {}'.format(struct1, struct2, msg or
                                                   ''))
    # numpy's array comparison concatenates err_msg with strings, so it must
    # never be None.
    err_msg = msg or ''
    vecs1 = numpy.array(xtal.get_vectors_from_chorus(struct1))
    vecs2 = numpy.array(xtal.get_vectors_from_chorus(struct2))
    npt.assert_almost_equal(vecs1, vecs2, 3, err_msg=err_msg)

    # Atom indices are 1-based
    atom_indices = list(range(1, struct1.atom_total + 1))
    struct2_xyz = struct2.getXYZ()
    try:
        armsd = rmsd.superimpose(struct1,
                                 atom_indices,
                                 struct2,
                                 atom_indices,
                                 use_symmetry=False,
                                 move_which=rmsd.CT)
    finally:
        # Revert coordinates of struct2 even if superposition fails
        struct2.setXYZ(struct2_xyz)
    npt.assert_almost_equal(armsd,
                            target_rmsd,
                            decimal=decimal,
                            err_msg=err_msg)
def mmexception_side_effect():
    """
    Get an MmException object that can be used as a side effect

    :rtype: mm.MmException
    :return: returns an mm.MmException object that can be used as a side
        effect for Mocks
    """
    # ``import unittest`` alone does not load the ``unittest.mock``
    # submodule, so import it explicitly here.
    from unittest import mock
    wrapper = mock.Mock()
    wrapper.function.__name__ = 'wrapper'
    # MmException needs 1) object with function.__name__, 2) a list of
    # arguments, and 3) a return code
    return mm.MmException(wrapper, [], 1)
class MSTestCase(unittest.TestCase):
    """
    TestCase with extra assertion helpers for structures, numpy arrays,
    dictionaries, mock calls, command lines, Qt tables and desmond stages.
    """

    def __init__(self, *args, **kwargs):
        """
        See parent for documentation.

        When using addTypeEqualityFunc for a new object, always add it to
        assertAlmostEqual.
        """
        # TODO: assertNotEqual should also use self._getAssertEqualityFunc
        # rather than using simple != comparision.
        super().__init__(*args, **kwargs)
        # Add structure equivalent check
        self.addTypeEqualityFunc(Structure, 'assertStructEqual')
        # Add numpy equivalent check.
        self.addTypeEqualityFunc(numpy.ndarray, 'numpy_equal')

    def numpy_equal(self, actual, desired, msg=None):
        """
        Assert that two arrays are exactly equal (``numpy.array_equal``).

        :param actual: the array to check
        :param desired: the expected array
        :param str or None msg: message shown if failure occurs
        :raise AssertionError: if the arrays differ in shape or values
        """
        if not numpy.array_equal(actual, desired):
            standardMsg = f'{actual} != {desired}.'
            self.fail(self._formatMessage(msg, standardMsg))

    def assertStructEqual(self, struct1, struct2, msg=None):
        """
        Assert that two structures are equivalent.

        The structures must have the same atom count and title, and
        ``struct1.isEquivalent(struct2)`` must be true.

        :param `structure.Structure` struct1: First structure
        :param `structure.Structure` struct2: Second structure
        :param str or None msg: message shown if failure occurs
        :raise AssertionError: when the two structures are not same.
        """
        if (struct1.atom_total != struct2.atom_total or
                struct1.title != struct2.title or
                not struct1.isEquivalent(struct2)):
            standardMsg = f'{struct1} != {struct2}.'
            self.fail(self._formatMessage(msg, standardMsg))

    def assertStructAlmostEqual(self,
                                struct1,
                                struct2,
                                target_rmsd=0.0,
                                decimal=1,
                                msg=None):
        """
        Assert that two structures are almost same.

        :param `structure.Structure` struct1: First structure
        :param `structure.Structure` struct2: Second structure
        :param float target_rmsd: Target RMSD for two structures
        :param int decimal: The number of decimal places used for comparing
            the two RMSD floats. In unittest needs to be set fairly loosely
            so that the tests pass on different OSs
        :param str or None msg: message shown if failure occurs
        :raise AssertionError: when the two struct are not equal within
            given tolerance
        """
        assert_same_struct(struct1=struct1,
                           struct2=struct2,
                           target_rmsd=target_rmsd,
                           decimal=decimal,
                           msg=msg)

    def assertAlmostEqual(self,
                          actual,
                          desired,
                          decimal=6,
                          msg=None,
                          places=None):
        """
        Override assertAlmostEqual with numpy testing assert_almost_equal.
        Add functionality to test structure and dictionaries as well.

        Dispatch is based on the type of ``actual`` only:
          - dict: same keys, values compared with numpy to ``decimal`` places
          - numpy array or list: numpy.testing.assert_almost_equal
          - Structure: assertStructAlmostEqual
          - anything else: unittest's assertAlmostEqual

        :param array_like or dict or `structure.Structure` actual: the object
            to check
        :param array_like or dict or `structure.Structure` desired: the
            expected object
        :param str or None msg: message shown if failure occurs
        :param int decimal: desired precision
        :param int places: Provided for compatibility with unittest API,
            which uses the places keyword rather than decimal. If given,
            will override the decimal value
        """
        if places is not None:
            decimal = places
        # numpy's array comparison concatenates err_msg with strings, so it
        # must be a str rather than None.
        err_msg = msg or ''
        if self.typeTest(dict, actual, desired, msg):
            standardMsg = f'{actual} != {desired}'
            for key, value in actual.items():
                if key not in desired:
                    self.fail(self._formatMessage(msg, standardMsg))
                npt.assert_almost_equal(value,
                                        desired[key],
                                        decimal=decimal,
                                        err_msg=err_msg)
            # All keys of actual are in desired; equal sizes => same keys
            if len(actual) != len(desired):
                self.fail(self._formatMessage(msg, standardMsg))
        elif (self.typeTest(numpy.ndarray, actual, desired, msg) or
              self.typeTest(list, actual, desired, msg)):
            npt.assert_almost_equal(actual, desired, decimal, err_msg)
        elif self.typeTest(Structure, actual, desired, msg):
            self.assertStructAlmostEqual(actual,
                                         desired,
                                         decimal=decimal,
                                         msg=msg)
        else:
            super().assertAlmostEqual(actual, desired, decimal, msg)

    def typeTest(self, obj_type, actual, desired, msg):
        """
        Check if actual is of the given type.

        Only ``actual`` is inspected; ``desired`` and ``msg`` are accepted
        for API compatibility and are not used.

        :param any class obj_type: Object type (or tuple of types) to check
        :param any class actual: the actual object to check the type of
        :param any class desired: unused
        :param msg: unused
        :return bool: True if actual is an instance of obj_type else False
        """
        return isinstance(actual, obj_type)

    def getCallResult(self,
                      mocked_function,
                      index=None,
                      key=None,
                      call_number=None,
                      msg=None):
        """
        Get the value for the passed argument for the mocked function

        :param `unittest.mock.MagicMock` mocked_function: mocked function
            which was called
        :param int index: 0-based index of the positional argument to
            return. If neither index nor key is given, the first positional
            argument is returned.
        :param str key: Name of the keyword argument to return
        :param int call_number: Index of the call made to the function.
            None means the most recent call, calls are indexed starting at 1.
        :param str or None msg: message shown if failure occurs
        :return list or tuple or dict or set or frozenset or str: value for
            the passed argument for the mocked function
        :raise AssertionError: the mocked fuction was not called, or was not
            called with input index or key.
        """
        self.assertTrue(mocked_function.called, msg)
        if key is not None and index is not None:
            standardMsg = ('Cannot compare both keyword and non-keyword '
                           'argument.')
            self.fail(self._formatMessage(msg, standardMsg))

        if call_number is None:
            c_arg, c_kwarg = mocked_function.call_args
        else:
            if call_number == 0:
                self.fail(
                    self._formatMessage(msg,
                                        'Calls are numbered starting at 1'))
            if mocked_function.call_count < call_number:
                standardMsg = (
                    f'Function did not get called {call_number} times.')
                self.fail(self._formatMessage(msg, standardMsg))
            c_arg, c_kwarg = mocked_function.call_args_list[call_number - 1]

        if key is not None:
            if key not in c_kwarg:
                standardMsg = f'Function not called with {key} argument.'
                self.fail(self._formatMessage(msg, standardMsg))
            return c_kwarg[key]

        if index is None:
            index = 0
        # Bounds-check every positional lookup (including index 0) so an
        # empty call fails cleanly instead of raising IndexError.
        if index >= len(c_arg):
            standardMsg = f'Function does not contain {index + 1} arguments.'
            self.fail(self._formatMessage(msg, standardMsg))
        return c_arg[index]

    def assertCallEqual(self,
                        actual,
                        mocked_function,
                        index=None,
                        argument=None,
                        call_number=None,
                        msg=None):
        """
        Check if the actual value is equal to value that the mocked fuction
        was called with.

        :param actual: actual value of the mock call to compare to
        :param `unittest.mock.MagicMock` mocked_function: mocked function
            which was called
        :param int index: index of the argument to compare
        :param str argument: name of the keyword argument
        :param int or None call_number: Number of the call made to the
            function. If none the last call will be taken and 1 will be
            oldest (first) call
        :param str or None msg: message shown if failure occurs
        :raise AssertionError: If actual value is not equal to value that
            the mocked fuction was called with or the fuction was not called
            with input index/argument.
        """
        desired = self.getCallResult(mocked_function, index, argument,
                                     call_number, msg)
        self.assertEqual(actual, desired, msg)

    def assertCallNotEqual(self,
                           actual,
                           mocked_function,
                           index=None,
                           argument=None,
                           call_number=None,
                           msg=None):
        """
        Check if the actual value is not equal to value that the mocked
        fuction was called with.

        :param actual: actual value of the mock call to compare to
        :param `unittest.mock.MagicMock` mocked_function: mocked function
            which was called
        :param int index: Index of the argument to compare
        :param str argument: Name of the keyword argument
        :param int or None call_number: Number of the call made to the
            function. If none the last call will be taken and 1 will be
            oldest (first) call
        :param str or None msg: message shown if failure occurs
        :raise AssertionError: If actual value is equal to value that the
            mocked fuction was called with or the fuction was not called
            with input index/argument.
        """
        desired = self.getCallResult(mocked_function, index, argument,
                                     call_number, msg)
        self.assertNotEqual(actual, desired, msg)

    def assertCallIn(self,
                     actual,
                     mocked_function,
                     index=None,
                     argument=None,
                     call_number=None,
                     msg=None):
        """
        Check if the actual value is in the container that the mocked
        fuction was called with.

        :param actual: value to check if it was in the call argument
        :param `unittest.mock.MagicMock` mocked_function: mocked function
            which was called
        :param int index: Index of the argument to compare
        :param str argument: Name of the keyword argument
        :param int or None call_number: Number of the call made to the
            function. If none the last call will be taken and 1 will be
            oldest (first) call
        :param str or None msg: message shown if failure occurs
        :raise AssertionError: If actual value is not in the value that the
            mocked fuction was called with or the fuction was not called
            with input index/argument.
        """
        desired = self.getCallResult(mocked_function, index, argument,
                                     call_number, msg)
        self.assertIn(actual, desired, msg)

    def assertCallNotIn(self,
                        actual,
                        mocked_function,
                        index=None,
                        argument=None,
                        call_number=None,
                        msg=None):
        """
        Check if the actual value is not in the container that the mocked
        fuction was called with.

        :param actual: value to check if it was not in the call argument
        :param `unittest.mock.MagicMock` mocked_function: mocked function
            which was called
        :param int index: Index of the argument to compare
        :param str argument: Name of the keyword argument
        :param int or None call_number: Number of the call made to the
            function. If none the last call will be taken and 1 will be
            oldest (first) call
        :param str or None msg: message shown if failure occurs
        :raise AssertionError: If actual value is in the value that the
            mocked fuction was called with or the fuction was not called
            with input index/argument.
        """
        desired = self.getCallResult(mocked_function, index, argument,
                                     call_number, msg)
        self.assertNotIn(actual, desired, msg)

    def assertCallAlmostEqual(self,
                              actual,
                              mocked_function,
                              index=None,
                              argument=None,
                              call_number=None,
                              msg=None,
                              decimal=6):
        """
        Check if the actual value is almost equal to value that the mocked
        fuction was called with.

        :param list or tuple or dict or set or frozenset or str actual:
            actual value of the mock call to compare to
        :param `unittest.mock.MagicMock` mocked_function: mocked function
            which was called
        :param int index: Index of the argument to compare
        :param str argument: Name of the keyword argument
        :param int or None call_number: Number of the call made to the
            function. If none the last call will be taken and 1 will be
            oldest (first) call
        :param str or None msg: message shown if failure occurs
        :param int decimal: desired precision
        :raise AssertionError: If actual value is not equal (within the
            decimal places) to value that the mocked fuction was called with
            or the fuction was not called with input index/argument.
        """
        desired = self.getCallResult(mocked_function, index, argument,
                                     call_number, msg)
        self.assertAlmostEqual(actual, desired, decimal=decimal, msg=msg)

    def assertCalledTimes(self, times, mocked_function, msg=None):
        """
        Check if the mocked fuction was called passed number of times

        :param int times: number of times mocked_function was called
        :param `unittest.mock.MagicMock` mocked_function: mocked function
            which was called
        :param str or None msg: message shown if failure occurs
        :raise AssertionError: If the mocked fuction was not called the
            passed amount of time
        """
        self.assertEqual(times, mocked_function.call_count, msg)

    def assertFlag(self, cmd, flag, value=None, is_in=True):
        """
        Check that an expected flag is in (or not in) the command line list
        and also check the value of that flag.

        Note that the language used here is for the command line, but this
        function works for any sequence where one might want to
        simultaneously check the ith and i+1th values together.

        Examples:
            self.assertFlag(cmd, '-gpu')
            self.assertFlag(cmd, '-HOST', value='bolt_personal:5')
            self.assertFlag(cmd, '-nosystem', is_in=False)

        :param list cmd: The list to check
        :param flag: The item to check if it is in cmd (may be any type)
        :param value: If given, the item following the first occurrence of
            flag will be checked to see if it is equal to value.
        :param bool is_in: If true, checks if flag is in cmd. If False,
            checks that flag is not in cmd, and value is ignored.
        :raise AssertionError: If flag violates the is_in setting, or if
            value is given and flag is the last item or the item after flag
            does not match.
        """
        if not is_in:
            self.assertNotIn(flag, cmd)
            return
        self.assertIn(flag, cmd)
        if value is None:
            return
        index = cmd.index(flag)
        # Fail cleanly (not IndexError) if the flag has no following value
        self.assertLess(index + 1, len(cmd), f'No value follows {flag}')
        self.assertEqual(cmd[index + 1], value)

    def assertTableContentsEqual(self,
                                 table,
                                 expected_contents,
                                 expected_column_headers=None):
        """
        Assert that the actual table contents match the expected ones.
        Optionally checks table headers too.

        Missing cells and missing header items compare as ''.

        :param `QtWidgets.QTableWidget` table: The table to get actual
            contents from
        :param list expected_contents: List of lists containing the desired
            data in the table. Each inner list has the contents of a row,
            and there must be one inner list per table row.
        :param list expected_column_headers: If not None, the column headers
            in the table will be compared with this list
        :raise AssertionError: If the contents are not equal
        """
        col_count = table.columnCount()
        if expected_column_headers is not None:
            actual_column_headers = []
            for col in range(col_count):
                header_item = table.horizontalHeaderItem(col)
                actual_column_headers.append(
                    header_item.text() if header_item is not None else '')
            self.assertSequenceEqual(actual_column_headers,
                                     expected_column_headers)
        # Without this, missing rows would pass silently and extra rows
        # would raise IndexError.
        self.assertEqual(table.rowCount(), len(expected_contents),
                         'Unexpected number of table rows')
        for row, expected_row in enumerate(expected_contents):
            actual = []
            for col in range(col_count):
                item = table.item(row, col)
                actual.append(item.text() if item is not None else '')
            self.assertSequenceEqual(actual, expected_row)

    def assertStageEqual(self, stage1, stage2):
        """
        Assert that the two desmond stages are equal.

        Keys 'title', 'checkpt', 'jobname' and 'backend' are not compared.
        Missing 'ensemble', 'annealing', 'temperature', 'timestep' and
        'pressure' keys are filled in with defaults before comparison.
        Note this modifies the passed stages in place.

        :param `sea.Map` stage1: first stage
        :param `sea.Map` stage2: Second stage
        :raise AssertionError: If the contents are not equal
        """
        default_values = {
            'ensemble': 'NPT',
            'annealing': 'false',
            'temperature': 300.0,
            'timestep': [0.002, 0.002, 0.006],
            'pressure': None
        }

        def set_default_values(stage, key):
            # Fill in the default for key if the stage does not define it
            try:
                stage[key]
            except KeyError:
                stage[key] = default_values[key]
            # Deal with barostat. It can be defined by two different
            # methods: by [press, barostat_type] or in the integrator.
            if key == 'pressure':
                try:
                    isotropy = stage['backend'].integrator.pressure.isotropy.val
                except KeyError:
                    isotropy = None
                if isotropy == 'anisotropic':
                    var = stage['pressure']
                    stage['pressure'] = [var, 'anisotropic']
            return stage

        nocheck = {'title', 'checkpt', 'jobname', 'backend'}
        combined_keys = (set(stage1.keys()) | set(stage2.keys())) - nocheck
        for key in combined_keys:
            if key in default_values:
                stage1 = set_default_values(stage1, key)
                stage2 = set_default_values(stage2, key)
            self.assertIn(key, stage1)
            self.assertIn(key, stage2)
            self.assertEqual(stage1[key], stage2[key])
def matsci_scripts_testfile(path):
    """
    Build the full path of a test file under
    ``$SCHRODINGER_SRC/mmshare/python/test/matsci``.

    :param str path: Relative path of file
    :rtype: str
    :return: Full path to file
    :raise TypeError: if SCHRODINGER_SRC is unset or empty
    """
    src_root = os.environ.get('SCHRODINGER_SRC', '')
    if src_root:
        subdirs = ('mmshare', 'python', 'test', 'matsci')
        return os.path.join(src_root, *subdirs, path)
    raise TypeError('SCHRODINGER_SRC is not set.')
def matsci_application_testfile(path):
    """
    Build the full path of a test file under
    ``$SCHRODINGER_SRC/mmshare/python/test/application/matsci``.

    :param str path: Relative path of file
    :rtype: str
    :return: Full path to file
    :raise TypeError: if SCHRODINGER_SRC is unset or empty
    """
    src_root = os.environ.get('SCHRODINGER_SRC', '')
    if src_root:
        subdirs = ('mmshare', 'python', 'test', 'application', 'matsci')
        return os.path.join(src_root, *subdirs, path)
    raise TypeError('SCHRODINGER_SRC is not set.')
def load_structure(filename):
    """
    Read the first structure from a file in the mmshare test-file directory.

    :param str filename: The name of the structure file
    :rtype: `structure.Structure`
    :return: The first structure in filename
    """
    full_path = sdgr_test.mmshare_testfile(filename)
    return Structure.read(full_path)
@contextlib.contextmanager
def mock_json_load_from_fh(module, contents_dict):
    """
    Context manager to mock json.load on a file handle in the given module
    and return the given contents dictionary.

    While active, ``module.open`` is replaced by a ``mock_open`` mock (it is
    created if the module has no ``open`` attribute). ``module.json.load``
    is replaced by a mock that returns ``contents_dict``. Both are restored
    on exit.

    :type module: module
    :param module: the module using open and json

    :type contents_dict: dict
    :param contents_dict: the contents dictionary

    :yield: the mock that replaces ``module.json.load``, so callers can
        inspect its calls (may be ignored)
    """
    # ``import unittest`` alone does not load the ``unittest.mock``
    # submodule, so import it explicitly here.
    from unittest import mock
    open_mock = mock.mock_open()
    with mock.patch.object(module, 'open', open_mock, create=True):
        with mock.patch.object(module.json, 'load') as load_mock:
            load_mock.return_value = contents_dict
            yield load_mock
def set_wa_dont_show_on_screen(widget):
    """
    Turn on the ``Qt.WA_DontShowOnScreen`` attribute of a widget.

    :param QWidget widget: The widget to modify in place.
    """
    attribute = Qt.WA_DontShowOnScreen
    widget.setAttribute(attribute, True)