Source code for schrodinger.utils.compare

"""
This is a postprocessor for the 'diff' command.  It hides differences
if they just represent numerical differences below some cutoff.  It
also ignores lines matching particular regular expressions, specified
in the file 'skip_lines'.
"""

import argparse
import pathlib
import re
import sys

# Maps alternative spellings of non-finite values to the canonical tokens
# 'inf', '-inf' and 'nan'.  translate_nan_tokens_in_line() looks up every
# whitespace-delimited token of a diff line in this table before the line is
# split into fields, so the lookup is an exact, case-sensitive token match.
# NOTE(review): the keys look like platform-specific formatted output of
# infinities/NaNs (e.g. '1.#INF'); confirm their origin before editing.
NAN_TOKEN_TRANSLATION_DICT = {
    '1.#INF': 'inf',
    '-1.#INF': '-inf',
    '1.#INF0': 'inf',
    '-1.#INF0': '-inf',
    '-1.#IND': 'nan',
    '-1.#R': 'nan',
    'nan0x7fffffff': 'nan',
    'NaNQ': 'nan',
    'INF': 'inf',
}

# A skip-file entry of the form /regex/: group 1 is the text between the
# leading "/" and the last "/" that is preceded by a non-backslash character
# (the match is greedy).  Used by get_regex_patterns_of_lines_to_skip().
REGEX_STRING_PATTERN = re.compile(r"^/(.*[^\\])/")
# A chunk header line such as "3c3" or "1,2d0": digits/commas, one lowercase
# letter (group 1), digits/commas.  Used by get_diff_chunks().
DIFF_KEY_PATTERN = re.compile(r"^[\d,]+([a-z])[\d,]+")
# A whole token that is a decimal number: optional sign, digits with an
# optional fractional part (or a bare ".digits"), and an optional exponent
# introduced by e, E, d or D.  convert_diff_token_to_float() rewrites the
# E/D/d markers to "e" before calling float().
NUMERICAL_FIELD_PATTERN = re.compile(
    r"^[-+]?(\d+(\.\d*)?|\.\d+)([eEdD][-+]?\d+)?$")
# Runs of whitespace and the listed punctuation characters, used as field
# separators.  The outer capturing group makes re.split() keep each separator
# run in its result, so get_diff_tokens_in_line() returns separators as
# tokens too.
WHITESPACE_AND_SYMBOLS_PATTERN = re.compile(
    r"([\s,;:|=_\(\)\{\}\[\]\*\/^\%#\@\$\"'`~]+)")


class DiffCompareArgumentParser(argparse.ArgumentParser):
    """Argument parser whose ``exit`` raises instead of ending the process.

    A nonzero exit request is turned into an ``argparse.ArgumentError`` so
    that code using the parser as a library can handle bad options itself.
    A zero status is a no-op.
    """

    def exit(self, status=0, message=None):
        """Raise ``argparse.ArgumentError`` unless *status* is 0.

        Args:
            status: Requested exit status; 0 means "nothing went wrong".
            message: Text carried by the raised error.

        Returns:
            None, when *status* is 0.

        Raises:
            argparse.ArgumentError: when *status* is nonzero.
        """
        if status == 0:
            return None
        raise argparse.ArgumentError(None, message)
def get_arg_parser():
    """Build the command-line parser for the diff postprocessor.

    Returns:
        A ``DiffCompareArgumentParser`` whose parsed namespace provides
        ``relative_error``, ``count_diffs``, ``minimum_mag``, ``skip_file``,
        ``ignore_sign`` and ``cutoff``.
    """
    parser = DiffCompareArgumentParser(
        description=__doc__,
        add_help=True,
        usage="diff <filea> <fileb> | compare [<options>] [<cutoff>]",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # (flags, keyword arguments), registered in the order shown in the help.
    # "-a"/"-r" share one destination, as do "-n"/"-s".
    argument_specs = (
        (("-a", "-abs"),
         dict(dest="relative_error",
              action="store_const",
              const=False,
              default=True,
              help=("Compare absolute differences between differing fields. "
                    "The default is to compare relative differences."))),
        (("-r", "-rel"),
         dict(dest="relative_error",
              action="store_const",
              const=True,
              default=True,
              help=("Compare relative differences between differing fields. "
                    "This is the default."))),
        (("-c",),
         dict(dest="count_diffs",
              action="store_const",
              const=True,
              default=False,
              help="Print number of difference blocks on last line of output.")),
        (("-m", "-mag"),
         dict(dest="minimum_mag",
              action="store",
              type=float,
              default=1.0,
              help=("Minimum magnitude to use for assessing relative diffs. "
                    "Must be positive. The default value is 1.0."))),
        (("-n", "-noskip"),
         dict(dest="skip_file",
              action="store_const",
              const="",
              default="skip_lines",
              help="Don't skip lines specified by skip file.")),
        (("-s", "-skip"),
         dict(dest="skip_file",
              action="store",
              type=str,
              default="skip_lines",
              help=("File containing regexps for lines to ignore. "
                    "By default the file './skip_lines' is used."))),
        (("-z",),
         dict(dest="ignore_sign",
              action="store_const",
              const=True,
              default=False,
              help="Compare magnitudes only (ignore sign differences).")),
        (("cutoff",),
         dict(action="store",
              nargs="?",
              type=float,
              default=1e-5,
              help=("Absolute or relative diff allowed between numerical fields. "
                    "Must be a positive number. The default is 0.00001."))),
    )
    for flags, options in argument_specs:
        parser.add_argument(*flags, **options)
    return parser
def get_diff_chunks(diff_lines):
    """Group the lines of ``diff`` output into chunks.

    Every input line is stripped of surrounding whitespace.  A line matching
    ``DIFF_KEY_PATTERN`` starts a new chunk; lines starting with "<" or ">"
    are collected into the current chunk.  All other lines (such as the
    "---" separator) are ignored, and "<"/">" lines seen before the first
    header are discarded.

    Args:
        diff_lines: Iterable of text lines produced by ``diff``.

    Yields:
        ``(diff_key, file1_lines, file2_lines)`` tuples, where ``diff_key``
        is the stripped header line and the two lists hold the stripped "<"
        and ">" lines respectively.  Chunks with no such lines are skipped.
    """
    header = ""
    left_lines, right_lines = [], []
    for raw_line in diff_lines:
        text = raw_line.strip()
        if DIFF_KEY_PATTERN.search(text):
            # A new header closes the chunk collected so far.
            if header and (left_lines or right_lines):
                yield header, left_lines, right_lines
            header, left_lines, right_lines = text, [], []
        if text.startswith("<"):
            left_lines.append(text)
        elif text.startswith(">"):
            right_lines.append(text)
    # Flush the final chunk, which has no following header to trigger it.
    if header and (left_lines or right_lines):
        yield header, left_lines, right_lines
def compare_diff_chunks(args, skip_patterns, diff_lines):
    """Filter each diff chunk down to the differences that still matter.

    Chunks whose header contains "c" (changed lines) go through
    ``compare_diff_lines`` for tolerance-aware comparison.  All other chunks
    go through ``process_nonchange_diffs``; for chunks whose header contains
    "a" (added lines), blank added lines (a bare ">") are also skipped.

    The caller's ``skip_patterns`` list is never modified.  The extra
    blank-line pattern is applied only to the "a" chunk being processed, so
    it cannot accumulate or affect later chunks.

    Args:
        args: Parsed options, passed through to ``compare_diff_lines``.
        skip_patterns: Compiled regular expressions for lines to ignore.
        diff_lines: Iterable of text lines produced by ``diff``.

    Yields:
        ``(diff_key, diff_output)`` tuples.  ``diff_output`` is an empty
        string when nothing in the chunk needs to be reported.
    """
    blank_added_line = re.compile(r"^>$")
    for diff_key, file1_lines, file2_lines in get_diff_chunks(diff_lines):
        if "c" in diff_key:
            kept_file1, kept_file2 = compare_diff_lines(
                diff_key, file1_lines, file2_lines, skip_patterns, args)
            diff_output = ""
            if kept_file1:
                diff_output = ("\n".join(kept_file1) + "\n---\n" +
                               "\n".join(kept_file2))
        elif "a" in diff_key:
            # Use a per-chunk copy rather than appending to the shared list.
            diff_output = process_nonchange_diffs(
                file1_lines, file2_lines, [*skip_patterns, blank_added_line])
        else:
            diff_output = process_nonchange_diffs(file1_lines, file2_lines,
                                                  skip_patterns)
        yield diff_key, diff_output
def process_nonchange_diffs(diff_output_from_file1, diff_output_from_file2,
                            skip_patterns):
    """Return the reportable lines of an add/delete chunk as one string.

    Lines matching any of ``skip_patterns`` are dropped, as are added blank
    lines (a line that is just ">" once stripped).

    Args:
        diff_output_from_file1: "<" lines of the chunk.
        diff_output_from_file2: ">" lines of the chunk.
        skip_patterns: Compiled regular expressions for lines to ignore.

    Returns:
        The remaining lines, file1 lines first, joined by newlines.  The
        string is empty when nothing remains.
    """
    remaining = [
        entry
        for entry in diff_output_from_file1 + diff_output_from_file2
        if not can_skip_line(entry, skip_patterns) and entry.strip() != ">"
    ]
    return "\n".join(remaining)
def can_skip_line(line, skip_patterns):
    """Tell whether *line* matches at least one of the skip patterns.

    Args:
        line: Text of a diff line.
        skip_patterns: Iterable of compiled regular expressions.

    Returns:
        True if any pattern's ``search`` finds a match in *line*, else False
        (including when there are no patterns).
    """
    for candidate in skip_patterns:
        if candidate.search(line):
            return True
    return False
def translate_nan_tokens_in_line(line):
    """Replace known inf/NaN spellings in *line* with canonical tokens.

    The line is split on whitespace, each token found in
    ``NAN_TOKEN_TRANSLATION_DICT`` is replaced by its mapped value, and the
    tokens are joined back with single spaces.  Runs of whitespace therefore
    collapse and leading/trailing whitespace is removed.

    Args:
        line: Text of a diff line.

    Returns:
        The normalised line.
    """
    return " ".join(
        NAN_TOKEN_TRANSLATION_DICT.get(word, word) for word in line.split())
def get_diff_tokens_in_line(diff_output_line):
    """Split a diff line into fields and the separators between them.

    Non-finite spellings are first normalised with
    ``translate_nan_tokens_in_line``.  The line is then split on
    ``WHITESPACE_AND_SYMBOLS_PATTERN``, whose capturing group keeps the
    separator runs in the returned list.

    Args:
        diff_output_line: Text of a diff line.

    Returns:
        List of field and separator strings, in order of appearance.
    """
    normalised = translate_nan_tokens_in_line(diff_output_line).strip()
    return WHITESPACE_AND_SYMBOLS_PATTERN.split(normalised)
def convert_diff_token_to_float(diff_token):
    """Convert a token to ``float`` if it looks like a decimal number.

    Args:
        diff_token: A single field from a diff line.

    Returns:
        The numeric value when the token matches ``NUMERICAL_FIELD_PATTERN``,
        otherwise None.
    """
    if NUMERICAL_FIELD_PATTERN.search(diff_token) is None:
        return None
    # float() only understands "e" as exponent marker; rewrite E, D and d.
    return float(re.sub(r"[EDd]", "e", diff_token))
def numerical_diff_is_acceptable(file1_float_token, file2_float_token, args):
    """Decide whether two numbers agree to within the configured cutoff.

    Args:
        file1_float_token: Value taken from the first file.
        file2_float_token: Value taken from the second file.
        args: Options object providing ``ignore_sign`` (compare magnitudes
            only), ``relative_error`` (scale the difference), ``minimum_mag``
            (lower bound for that scale) and ``cutoff`` (largest accepted
            difference).

    Returns:
        True when the absolute or relative difference is ``<= args.cutoff``.
        The relative difference is the absolute difference divided by
        ``max(minimum_mag, |a| + |b|)``.
    """
    first, second = file1_float_token, file2_float_token
    if args.ignore_sign:
        first, second = abs(first), abs(second)

    # Equal values differ by exactly zero.  This is spelled out because two
    # values that both overflowed to infinity would otherwise yield
    # inf - inf = NaN, which fails every comparison and would report
    # identical fields as different.
    difference = 0.0 if first == second else abs(second - first)

    if args.relative_error:
        scale = max(args.minimum_mag, abs(first) + abs(second))
        # A non-positive scale (both values zero and minimum_mag <= 0)
        # leaves the difference unscaled.
        if scale > 0.0:
            difference /= scale
    return difference <= args.cutoff
def _diff_line_pair_differs(file1_diff_line, file2_diff_line, args):
    """Return True if two diff lines differ by more than the tolerance."""
    file1_tokens = get_diff_tokens_in_line(file1_diff_line)
    file2_tokens = get_diff_tokens_in_line(file2_diff_line)
    if len(file1_tokens) != len(file2_tokens):
        return True

    for file1_token, file2_token in zip(file1_tokens, file2_tokens):
        # The leading diff markers always differ; they are not content.
        if file1_token == "<" and file2_token == ">":
            continue
        file1_value = convert_diff_token_to_float(file1_token)
        file2_value = convert_diff_token_to_float(file2_token)
        if file1_value is None or file2_value is None:
            # At least one side is not a number: compare as text.
            if file1_token != file2_token:
                return True
            continue
        if not numerical_diff_is_acceptable(file1_value, file2_value, args):
            return True
    return False


def compare_diff_lines(diff_key, diff_output_from_file1,
                       diff_output_from_file2, skip_patterns, args):
    """Keep only the changed line pairs that differ beyond the tolerance.

    Lines are compared pairwise in order.  A pair is dropped when both lines
    match a skip pattern, or when every field agrees: textual fields must be
    identical and numeric fields must satisfy
    ``numerical_diff_is_acceptable``.  Each remaining pair is reported once,
    however many of its fields differ.

    Args:
        diff_key: Header of the chunk (not used by the comparison).
        diff_output_from_file1: "<" lines of the chunk.
        diff_output_from_file2: ">" lines of the chunk.
        skip_patterns: Compiled regular expressions for lines to ignore.
        args: Options object passed to ``numerical_diff_is_acceptable``.

    Returns:
        A pair ``(file1_lines, file2_lines)`` of equally long lists.  When
        the two inputs have different lengths they cannot be paired and are
        returned unchanged.
    """
    if len(diff_output_from_file1) != len(diff_output_from_file2):
        return diff_output_from_file1, diff_output_from_file2

    file1_diff_lines = []
    file2_diff_lines = []
    for file1_diff_line, file2_diff_line in zip(diff_output_from_file1,
                                                diff_output_from_file2):
        if (can_skip_line(file1_diff_line, skip_patterns) and
                can_skip_line(file2_diff_line, skip_patterns)):
            continue
        if _diff_line_pair_differs(file1_diff_line, file2_diff_line, args):
            file1_diff_lines.append(file1_diff_line)
            file2_diff_lines.append(file2_diff_line)
    return (file1_diff_lines, file2_diff_lines)
def get_regex_patterns_of_lines_to_skip(args):
    """Collect the regular expressions for diff lines to be ignored.

    The result always starts with a built-in pattern for
    "WARNING DEPRECATION" lines.  If ``args.skip_file`` names an existing
    file, every line of the form ``/regex/`` in it contributes one more
    pattern.  Lines starting with "#" or "$" and lines not in that form are
    ignored.

    Args:
        args: Options object providing ``skip_file``.  An empty string
            disables the skip file.

    Returns:
        List of compiled regular expressions.
    """
    patterns = [re.compile("^[<>] WARNING DEPRECATION")]
    skip_file = args.skip_file
    if skip_file == "" or not pathlib.Path(skip_file).is_file():
        return patterns

    with open(skip_file, "r") as handle:
        for raw_line in handle:
            entry = raw_line.strip()
            if entry.startswith(("#", "$")):
                continue
            found = REGEX_STRING_PATTERN.search(entry)
            if found is not None:
                patterns.append(re.compile(found.group(1)))
    return patterns
def get_diff_output_with_numerical_tolerance(diff_lines, *options):
    """Filter ``diff`` output, hiding differences within tolerance.

    Args:
        diff_lines: Iterable of text lines produced by ``diff``.
        *options: Command-line style options; each is converted with
            ``str`` and parsed by the parser from ``get_arg_parser``.

    Returns:
        ``(output, exit_status)``.  ``output`` holds the remaining chunks,
        each as its header line followed by its lines, joined by newlines.
        ``exit_status`` is 1 if any chunk remains and 0 otherwise.
    """
    args = get_arg_parser().parse_args([str(option) for option in options])
    skip_patterns = get_regex_patterns_of_lines_to_skip(args)
    reported = [
        diff_key + "\n" + diff_output
        for diff_key, diff_output in compare_diff_chunks(
            args, skip_patterns, diff_lines)
        if diff_output
    ]
    exit_status = 1 if reported else 0
    return "\n".join(reported), exit_status
def main():
    """Filter ``diff`` output read from standard input.

    Options are taken from ``sys.argv[1:]``.  Any remaining differences are
    printed to standard output.

    Returns:
        The exit status from ``get_diff_output_with_numerical_tolerance``.
    """
    cli_options = sys.argv[1:]
    report, status = get_diff_output_with_numerical_tolerance(
        sys.stdin, *cli_options)
    if report:
        print(report)
    return status
if __name__ == "__main__":
    # Hand main()'s return value to the shell as the exit status (0 or 1).
    status = main()
    sys.exit(status)