# Source code for schrodinger.test.stu.outcomes.compare_csv

"""
Outcome workup that compares a test CSV file against a reference CSV file.
See the compare_csv docstring for details.

$Revision 0.2 $

@copyright: (c) Schrodinger, LLC. All rights reserved
"""

import numbers

import numpy as np
import pandas

from schrodinger.utils import log

#: Default tolerance if none is provided.
DEFAULT_TOLERANCE = 5e-3


def help():
    """
    Return the usage text for this workup.

    :return: One-line usage string listing the positional arguments.
    :rtype: str
    """
    usage = "Usage: compare_csv(test.csv, reference.csv, [tolerance], [lines])"
    return usage
def compare_csv(test_file,
                ref_file,
                tolerance=None,
                reltol=None,
                lines=None,
                delimiter=None,
                ignore_cols=None,
                sort_by=None,
                skip_rows=None,
                comment=None):
    """
    Workup for comparing two CSV files. Example use::

        outcome_workup = compare_csv('test.csv', 'ref.csv', tolerance=0.05,
                                     lines=3, delimiter=' ')

    Numeric values will be compared using tolerance or reltol, as described
    below. Equality is required for values that cannot be cast as floats
    (strings for example).

    The lines in each CSV file are expected to line up (i.e. line 1 in
    test.csv is compared with line 1 from ref.csv). This means if a line is
    skipped in test.csv all subsequent lines will cause failures (so, many
    failure messages will be printed - one for each line after the skip).

    :param str test_file: Filename of csv to be tested.

    :param str ref_file: Filename of reference csv.

    :param float|None tolerance: Maximum absolute deviation from the ref csv
        for numeric values. If this and reltol are both None,
        `DEFAULT_TOLERANCE` will be used.

    :param float|None reltol: Maximum possible deviation from the ref csv,
        expressed as a relative value. For example if this is 0.02, values
        may be different by up to 2%.

    :param int lines: Number of lines to compare. Default is to compare all
        lines and require that the same number of lines to be in the
        reference and test files. When given, both files must provide the
        same number of rows up to this limit.

    :param str delimiter: Delimiter to use while reading the csv. The default
        delimiter is ','.

    :param list ignore_cols: List of column names to ignore

    :param sort_by: Before comparing, sort the ref_file and test_file based on
        the column name(s) specified in sort_by. The sort is stable, so rows
        with equal keys keep their file order.
    :type sort_by: `list` of `str`

    :param skip_rows: Line numbers to skip specified by the list (0 indexed).
        If an integer is specified, number of lines to skip from the start of
        file.
    :type skip_rows: `list` or `int`

    :param str comment: Indicates that commented lines should not be parsed.
        If found at the beginning of a line, the line will be ignored
        altogether. This parameter must be a single character.

    :raises AssertionError: If both tolerance and reltol are given, or if the
        files differ. Details of the differences are written to
        ``workup_compare_csv.log``.
    :raises TypeError: If the tolerance in use is not a real number.

    :return: True if the files match.
    :rtype: bool
    """
    tolerance_args = _get_tolerance_args(tolerance, reltol)

    read_kwargs = dict(delimiter=delimiter, skiprows=skip_rows, comment=comment)
    ref_df = pandas.read_csv(ref_file, **read_kwargs)
    test_df = pandas.read_csv(test_file, **read_kwargs)
    if ignore_cols:
        # Remove columns without raising if a column doesn't exist
        ref_df = ref_df.drop(columns=ignore_cols, errors='ignore')
        test_df = test_df.drop(columns=ignore_cols, errors='ignore')

    violations = []
    n_ref_rows, n_ref_cols = ref_df.shape
    n_test_rows, n_test_cols = test_df.shape
    if n_ref_cols != n_test_cols:
        msg = 'Number of columns do not match. Reference: {}, Test: {}'
        violations.append(msg.format(n_ref_cols, n_test_cols))
    elif lines is None and n_ref_rows != n_test_rows:
        msg = 'Number of rows do not match. Reference: {}, Test: {}'
        violations.append(msg.format(n_ref_rows, n_test_rows))
    elif lines and min(n_ref_rows, lines) != min(n_test_rows, lines):
        # Row-by-row comparison zips the two frames and would silently stop
        # at the shorter one, hiding rows missing from the compared range.
        msg = ('Number of rows to compare do not match (lines={}). '
               'Reference: {}, Test: {}')
        violations.append(
            msg.format(lines, min(n_ref_rows, lines), min(n_test_rows, lines)))
    else:
        violations.extend(
            compare_dfs(_sort_rows(ref_df, sort_by),
                        _sort_rows(test_df, sort_by), tolerance_args, lines))

    if violations:
        _report_violations(test_file, ref_file, violations)
    return True


def _get_tolerance_args(tolerance, reltol):
    """
    Validate `tolerance`/`reltol` and build keyword arguments for np.isclose.

    :raises AssertionError: If both tolerance and reltol are given.
    :raises TypeError: If the value that will be used is not a real number.
    :return: Either ``{'atol': tolerance, 'rtol': 0.0}`` or
        ``{'rtol': reltol}``.
    :rtype: dict
    """
    # Compare against None so an explicit 0 still counts as "given".
    if tolerance is not None and reltol is not None:
        msg = 'Only one of tolerance and reltol can be used.'
        raise AssertionError(msg)
    if tolerance is None and reltol is None:
        tolerance = DEFAULT_TOLERANCE
    value = tolerance if tolerance is not None else reltol
    # np.isreal(None) / np.isreal('0.05') are True, so check the type itself.
    if not isinstance(value, numbers.Real):
        msg = ("One of tolerance and reltol must be defined and numeric "
               "(found \"tolerance {}, reltol {}\")").format(tolerance, reltol)
        raise TypeError(msg)
    if tolerance is not None:
        # np.isclose defaults to rtol=1e-05, which would widen an absolute
        # tolerance in proportion to the magnitude of the values compared.
        return {'atol': tolerance, 'rtol': 0.0}
    # Keep numpy's tiny default atol as a floor for values at/near zero.
    return {'rtol': reltol}


def _sort_rows(df, sort_by):
    """
    Return `df` sorted by the `sort_by` column(s), if any, re-indexed 0..n-1.

    The stable sort keeps rows with equal keys in file order; the fresh index
    makes row labels equal row positions for the row-by-row comparison.
    """
    if sort_by:
        df = df.sort_values(sort_by, kind='mergesort')
    return df.reset_index(drop=True)


def _report_violations(test_file, ref_file, violations):
    """
    Write `violations` to the workup log file and raise AssertionError.

    :raises AssertionError: Always, pointing at the log file.
    """
    log_name = 'workup_compare_csv.log'
    log.logging_config(file=log_name, format='%(message)s', filemode='w')
    logger = log.get_logger(log_name)
    logger.warning(f'Errors comparing {test_file} and {ref_file}')
    for violation in violations:
        logger.warning(violation)
    raise AssertionError("FAILURE: File %s was different from %s\n"
                         "Details can be found in %s" %
                         (test_file, ref_file, log_name))
def compare_dfs(ref_df, test_df, tolerance_args, lines):
    """
    Compare test and reference dfs using specified tolerance. Return a list
    of any violations.

    Rows are paired by position (first with first, and so on) and values
    within a row are paired by column position. Values that are not real
    numbers must be exactly equal; numeric values must satisfy
    ``np.isclose(test_val, ref_val, **tolerance_args)``; NaN in both frames
    counts as a match.

    :param pandas.DataFrame ref_df: Reference data.
    :param pandas.DataFrame test_df: Data to be tested.
    :param dict tolerance_args: Keyword arguments forwarded to `np.isclose`,
        e.g. ``{'atol': 0.005}`` or ``{'rtol': 0.02}``.
    :param int|None lines: If truthy, only the first `lines` rows are
        compared; otherwise all paired rows are compared.
    :return: One message per mismatching value, in row then column order.
    :rtype: list of str
    """
    violations = []
    row_pairs = zip(ref_df.iterrows(), test_df.iterrows())
    for row_num, ((_, ref_row), (_, test_row)) in enumerate(row_pairs):
        # Count rows by position, not index label: after sorting (or with a
        # non-default index) labels no longer reflect how many rows were seen.
        if lines and row_num >= lines:
            break
        for ref_val, test_val in zip(ref_row.tolist(), test_row.tolist()):
            msg = None
            if not np.isreal(ref_val) or not np.isreal(test_val):
                if ref_val != test_val:
                    msg = 'Error in row {}: value {} does not match reference value {}.'
            elif np.isnan(ref_val) and np.isnan(test_val):
                # Special case since nans are not equal to each other.
                pass
            elif not np.isclose(test_val, ref_val, **tolerance_args):
                # np.isclose scales rtol by its *second* argument; passing the
                # reference value there makes reltol relative to the reference.
                msg = 'Error in row {}: value {} does not match reference value {} within tolerance'
            if msg:
                violations.append(msg.format(row_num, test_val, ref_val))
    return violations
if __name__ == "__main__":
    import sys

    # Command-line form of the workup, matching help():
    #   compare_csv.py test.csv reference.csv [tolerance] [lines]
    # Arguments are converted explicitly because compare_csv's third and
    # fourth positional parameters are tolerance and reltol, not lines.
    cli_args = sys.argv[1:]
    if not 2 <= len(cli_args) <= 4:
        # sys.exit with a string prints it to stderr and exits with status 1.
        sys.exit(help())
    test_csv, ref_csv = cli_args[:2]
    cli_tolerance = float(cli_args[2]) if len(cli_args) > 2 else None
    cli_lines = int(cli_args[3]) if len(cli_args) > 3 else None
    ret = compare_csv(test_csv, ref_csv, tolerance=cli_tolerance,
                      lines=cli_lines)
    if ret:
        print('Workflow passed')