# Source code for schrodinger.test.stu.outcomes.custom.ffld_workups

"""
Special STU workups for testing the OPLS force field and FFBuilder.

Copyright Schrodinger LLC, All Rights Reserved.
"""

import collections
import os
import sqlite3 as lite
import traceback
import zipfile
from pathlib import Path
from typing import Optional
from typing import Set

from schrodinger.forcefield import common
from schrodinger.forcefield import constants
from schrodinger.infra import mm
from schrodinger.test.stu.outcomes import failures
from schrodinger.test.stu.outcomes import structure_comparisons
from schrodinger.utils import fileutils
from schrodinger.utils import mmutil

try:
    from schrodinger.forcefield.packages import common_ffb
    FFB_ENCRYPTED_PROPS = [common_ffb.UTT_DEF_PROP]
    IGNORED_MAE_PROPS = ["s_m_Source_File", "s_m_Source_Path", "b_ff_acyclic", "b_ff_rotatable"] + \
        FFB_ENCRYPTED_PROPS
except ImportError as e:
    common_ffb = e
    FFB_ENCRYPTED_PROPS = None
    IGNORED_MAE_PROPS = None

try:
    from schrodinger.forcefield.packages import utt_ffb
except ImportError as e:
    utt_ffb = e

try:
    from schrodinger.forcefield.packages.test import utt_comparison
except ImportError as e:
    utt_comparison = e

# UTT definition-string columns in the legacy FFBUILDER_SQL database; their
# values are stored encoded and are decoded before comparison (see
# _assert_same_row_values) so that encoding-only differences are tolerated.
UTT_DEF_STRING_NAMES = ["param_def", "param_fit"]
# Equivalent encoded columns in the FFBUILDER_DB_V2 database.
UTT_DEF_STRING_NAMES_DB_V2 = ["param_def", "enc_def"]
# Message used when a database filename is neither of the two known
# ffbuilder database names (raised via DBNotRecognizedError).
DB_NOT_RECOGNIZED_MSG = (
    f'Database name not recognized!\n'
    f'Expecting {constants.FFBUILDER_DB_V2} or {constants.FFBUILDER_SQL}.')


class DBNotRecognizedError(ValueError):
    """Raised when a database filename is not a recognized ffbuilder db."""
def _pull_db_data(dbname, ignore_tables: Optional[Set[str]] = None): """ Extract data from database. :param ignore_table: names of tables to be ignored. """ with lite.connect(os.path.normpath(dbname)) as db: # get all table names: cursor = db.execute( "SELECT name FROM sqlite_master WHERE type='table';") tables = {table_name for table_name, *_ in cursor.fetchall()} if ignore_tables: tables -= ignore_tables for table_name in sorted(tables): # get list of column names, and sort to canonicalize cursor = db.execute(f"PRAGMA table_info({table_name})") column_list = sorted(x[1] for x in cursor) # retrieve row data in canonical column order column_spec = ",".join(column_list) select_cmd = f"select {column_spec} from {table_name}" cursor = db.execute(select_cmd) row_list = cursor.fetchall() yield table_name, column_list, row_list def _assert_same_column_names(ref_column_list, test_column_list): ref_name_set = set(ref_column_list) test_name_set = set(test_column_list) excess_ref_name_list = sorted(ref_name_set - test_name_set) excess_test_name_list = sorted(test_name_set - ref_name_set) errmsg = 'Columns are not as expected --\n\n' if excess_ref_name_list: errmsg += f'Only ref db has column(s):\n{excess_ref_name_list}\n\n' if excess_test_name_list: errmsg += f'Only test db has column(s):\n{excess_test_name_list}\n\n' errmsg += f'Reference:\n{ref_column_list}\n\n' errmsg += f'Test:\n{test_column_list}\n' assert test_column_list == ref_column_list, errmsg def _assert_same_row_counts(ref_row_list, test_row_list): num_ref_rows = len(ref_row_list) num_test_rows = len(test_row_list) errmsg = 'Row count is not as expected --\n' errmsg += f'Reference row count: {num_ref_rows}\n' errmsg += f'Test row count: {num_test_rows}' assert num_test_rows == num_ref_rows, errmsg def _assert_same_row_values(dbname, column_list, ref_row_list, test_row_list): errors = [] for nrow, (ref_row, test_row) in enumerate(zip(ref_row_list, test_row_list), 1): for name, ref_value, test_value in 
zip(column_list, ref_row, test_row): if dbname == constants.FFBUILDER_DB_V2: if name in UTT_DEF_STRING_NAMES_DB_V2 and ref_value and test_value: ref_value = common_ffb.decode_utt_def_string(ref_value) test_value = common_ffb.decode_utt_def_string(test_value) elif dbname == constants.FFBUILDER_SQL: # For UTT def string columns, arrange to tolerate differences # due to encoding only. For those columns, must first check # that pulled value is not None, because at least one test # (#29533) has param_fit = NULL in one row (correctly). if name in UTT_DEF_STRING_NAMES and ref_value and test_value: ref_value = common_ffb.decode_utt_def_string(ref_value) test_value = common_ffb.decode_utt_def_string(test_value) else: raise DBNotRecognizedError(DB_NOT_RECOGNIZED_MSG) if test_value != ref_value: errors.append(f"{nrow}, {name}, {test_value}, {ref_value}") header = "Row, Column Name, Test Value, Reference Value\n" assert not errors, header + "\n".join(errors)
def compare_ffb_sql(check_dbfile, reference_dbfile):
    """
    Compares 'ffb' tables only, not 'meta' tables.

    :param check_dbfile: path to the database produced by the test job
    :param reference_dbfile: path to the reference database
    :raise DBNotRecognizedError: if the database basename is not recognized
    :raise failures.WorkupFailure: if the databases differ
    :return: True if all tables are identical
    """
    basename = os.path.basename(check_dbfile)
    if basename == constants.FFBUILDER_DB_V2:
        ignore_tables = {'run'}
    elif basename == constants.FFBUILDER_SQL:
        ignore_tables = {'meta'}
    else:
        raise DBNotRecognizedError(DB_NOT_RECOGNIZED_MSG)
    # Materialize both sides so a difference in the number of tables is
    # detected; a plain zip() would silently ignore trailing tables present
    # in only one of the databases.
    ref_tables = list(
        _pull_db_data(reference_dbfile, ignore_tables=ignore_tables))
    test_tables = list(
        _pull_db_data(check_dbfile, ignore_tables=ignore_tables))
    if len(ref_tables) != len(test_tables):
        raise failures.WorkupFailure(
            f"Table counts don't match: test has {len(test_tables)} "
            f"table(s), reference has {len(ref_tables)} table(s)")
    for ref_data, test_data in zip(ref_tables, test_tables):
        ref_table_name, ref_column_list, ref_row_list = ref_data
        test_table_name, test_column_list, test_row_list = test_data
        if ref_table_name != test_table_name:
            raise failures.WorkupFailure(
                f"Table names don't match: "
                f"test_table_name={test_table_name} ref_table_name={ref_table_name}"
            )
        try:
            _assert_same_column_names(ref_column_list, test_column_list)
            _assert_same_row_counts(ref_row_list, test_row_list)
            _assert_same_row_values(basename, ref_column_list, ref_row_list,
                                    test_row_list)
        except AssertionError as exc:
            msg = f"Database '{ref_table_name}' tables do not match:\n"
            msg += f"{reference_dbfile}\n{check_dbfile}\n\n{exc}"
            raise failures.WorkupFailure(msg) from exc
    # Return True if all tables are identical
    return True
def _get_decoded_property(st, prop):
    """Return the decoded value of an encoded structure property, or None
    if the property is unset."""
    value = st.property.get(prop)
    if value is None:
        return value
    return common_ffb.decode_utt_def_string(value)


def _compare_encoded_properties(check, reference):
    """
    Compares ffld encrypted property between two maestro files

    :raise structure_comparisons.CompositeFailure: if any encoded property
        differs between corresponding structures
    """
    failure_msgs = collections.defaultdict(list)
    for i, st1, st2 in structure_comparisons._zip_structures(check, reference):
        for prop in FFB_ENCRYPTED_PROPS:
            value1 = _get_decoded_property(st1, prop)
            value2 = _get_decoded_property(st2, prop)
            if value1 != value2:
                # Append (as the defaultdict(list) intends) rather than
                # assign, so multiple mismatching properties on the same
                # structure are all reported instead of only the last one.
                failure_msgs[f"Structure {i:02d}"].append(
                    f"{prop} mismatch: {value1} != {value2}")
    if failure_msgs:
        msg = ("Maestro encrypted properties did not match in %s and %s." %
               (check, reference))
        raise structure_comparisons.CompositeFailure(msg, failure_msgs)
def compare_ffbuilder_mae_files(check, reference, *properties, **tolerances):
    """
    Compares maestro files output from ffbuilder jobs.  This is effectively
    the same call as compare_mae_files, but handles encoded UTT properties.
    """
    # Mark properties whose raw (encoded) values are expected to differ so
    # the generic comparison skips them.
    ignore_map = {
        prop: structure_comparisons.IGNORE for prop in IGNORED_MAE_PROPS
    }
    tolerances.update(ignore_map)
    structure_comparisons.compare_mae_files(check, reference, *properties,
                                            **tolerances)
    # Compare the encoded properties that were previously ignored
    _compare_encoded_properties(check, reference)
    return True
@common.mmffld_environment()
def compare_parameter_files(check, reference, v_term_tol=0.0):
    """
    Compares individual OPLS parameter files.

    :param check: path to the test parameter file
    :param reference: path to the reference parameter file
    :param v_term_tol: tolerance passed to the utt data comparison
    :raise failures.WorkupFailure: if either file is missing or the
        parameter data differ
    :return: True if the files match
    """
    if not os.path.isfile(check) or not os.path.isfile(reference):
        msg = f"Both {check} and {reference} must exist"
        raise failures.WorkupFailure(msg)
    check_utt_data = utt_ffb.read_uttdat_file(check)
    ref_utt_data = utt_ffb.read_uttdat_file(reference)
    try:
        utt_comparison.assert_same_utt_data(check_utt_data, ref_utt_data,
                                            v_term_tol)
    except AssertionError as exc:
        # get traceback since currently asserts have no messages;
        # format_exc() already returns a single string, so the previous
        # "".join(...) around it was a no-op and is dropped.
        tb_lines = traceback.format_exc()
        msg = "OPLS_DIR parameter files do not match:\n"
        msg += f"{check}\n{reference}\n{tb_lines}"
        raise failures.WorkupFailure(msg) from exc
    return True
def _extract_project_archive(project_archive):
    """
    Extracts the given zipfile into a directory from the archive basename.

    :param project_archive: ffbuilder project archive
    :type project_archive: str

    :return: directory which the project archive was extracted into
    :rtype: str
    """
    target_dir = fileutils.get_jobname(project_archive)
    # allow extracting into existing directory so we can run workup more than
    # once
    if not os.path.isdir(target_dir):
        os.mkdir(target_dir)
    with zipfile.ZipFile(project_archive) as archive:
        archive.extractall(target_dir)
    return target_dir


def _collect_and_compare_filenames(check_dir, reference_dir):
    """
    Assert both directory trees contain exactly the same relative file
    paths and return that set of paths.
    """
    if isinstance(common_ffb, ImportError):
        raise common_ffb

    # ignore MERGED_SUBDIR_ORIGIN file so we can support FFB_PER_FRAGMENT on/off
    # can be removed with FFB_PER_FRAGMENT as this function will only be used
    # for min dir
    def _relative_paths(dirname):
        found = set()
        for dirpath, _, filenames in os.walk(dirname):
            for filename in filenames:
                if filename == common_ffb.MERGED_SUBDIR_ORIGIN:
                    continue
                found.add(str((Path(dirpath) / filename).relative_to(dirname)))
        return found

    check_files = _relative_paths(check_dir)
    reference_files = _relative_paths(reference_dir)

    def _only_in(first, second):
        return "\n " + "\n ".join(sorted(first - second))

    if check_files != reference_files:
        msg = "Existing files do not match; unique files present:\n\n"
        msg += f"{check_dir}: {_only_in(check_files, reference_files)}\n\n"
        msg += f"{reference_dir}: {_only_in(reference_files, check_files)}\n\n"
        raise failures.WorkupFailure(msg)
    return check_files
def compare_ffb_proj_zip(check, reference):
    """
    Compares force field builder project directory archives, first unzipping
    them to separate subdirectories, and then comparing toplevel files listed
    in KNOWN_OUTPUT_FILES with the appropriate file comparison routine.

    :param check: path to the test project archive
    :param reference: path to the reference project archive
    :raise failures.WorkupFailure: if the archives differ
    :return: True if the project archives match
    """
    # Create subdirs and extract archive contents in each
    check_dir = _extract_project_archive(check)
    reference_dir = _extract_project_archive(reference)
    if mmutil.feature_flag_is_enabled(mmutil.FFBUILDER_DB_V2):
        ffdb_file = constants.FFBUILDER_DB_V2
    else:
        ffdb_file = constants.FFBUILDER_SQL
    files_comparisons = [
        (ffdb_file, compare_ffb_sql),
        (constants.FFBUILDER_FRAGS_FILE, compare_ffbuilder_mae_files),
    ]
    for filename, comparison in files_comparisons:
        check_filename = os.path.join(check_dir, filename)
        reference_filename = os.path.join(reference_dir, filename)
        # Ignore if files not present in either
        if not os.path.isfile(check_filename) and not os.path.isfile(
                reference_filename):
            continue
        # Fail if files not present in both.
        if not (os.path.isfile(check_filename) and
                os.path.isfile(reference_filename)):
            # Name the offending file in the message (it previously
            # contained a literal "(unknown)" placeholder).
            raise failures.WorkupFailure(
                f"{filename} should be present in both test and reference "
                f"directories!")
        comparison(check_filename, reference_filename)
    # Check existence of all files in the min tree by checking that all paths
    # are present in both directories; use new names instead of rebinding
    # the check/reference parameters.
    check_min_dir = os.path.join(check_dir, constants.MINDIR)
    reference_min_dir = os.path.join(reference_dir, constants.MINDIR)
    _collect_and_compare_filenames(check_min_dir, reference_min_dir)
    return True
def compare_OPLS_DIR(check, reference, v_term_tol=0.0):
    """
    Compares force field builder project output OPLS_DIR directories.
    Currently this only compares the OPLS3e f16_utt.dat files.
    """
    check_file = mm.get_archive_path(check)
    # test opls archive is any file with opls extension
    # (name is release-independent)
    reference_files = sorted(Path(reference).glob('*.opls'))
    if len(reference_files) == 1:
        compare_parameter_files(check_file, reference_files[0], v_term_tol)
        return True
    raise failures.WorkupFailure(
        "Expected exactly one reference opls archive, "
        f"got {reference_files}")