Source code for schrodinger.models.diffy

"""
Need a diff in a jiffy? Use diffy!
"""
import collections
from collections import abc
from functools import singledispatch

from schrodinger.utils.scollections import IdDict
from schrodinger.utils.scollections import IdSet
"""
Generic Functions
"""


@singledispatch
def get_diff(new_state, old_state):
    """
    Given two states of an object, calculate what changed between them.

    This is the generic fallback. Type-specific implementations are attached
    with `get_diff.register` and are selected by the type of `new_state`
    (the first positional argument).

    :param new_state: The current state of the object.
    :param old_state: The previous state of the object.
    :raises NotImplementedError: If no implementation is registered for the
        type of `new_state`.
    """
    # singledispatch chooses the implementation from the first argument, so
    # that is the type that is actually missing an implementation.
    err_msg = f'get_diff not implemented for type {type(new_state).__name__}'
    raise NotImplementedError(err_msg)
@singledispatch
def get_removed(new_state, old_state):
    """
    Calculate what was removed going from `old_state` to `new_state`.

    This is the generic fallback. Type-specific implementations are attached
    with `get_removed.register` and are selected by the type of `new_state`
    (the first positional argument).

    :param new_state: The current state of the object.
    :param old_state: The previous state of the object.
    :raises NotImplementedError: If no implementation is registered for the
        type of `new_state`.
    """
    # singledispatch chooses the implementation from the first argument, so
    # that is the type that is actually missing an implementation.
    err_msg = f'get_removed not implemented for type {type(new_state).__name__}'
    raise NotImplementedError(err_msg)
@singledispatch
def get_added(new_state, old_state):
    """
    Calculate what was added going from `old_state` to `new_state`.

    This is the generic fallback. Type-specific implementations are attached
    with `get_added.register` and are selected by the type of `new_state`
    (the first positional argument).

    :param new_state: The current state of the object.
    :param old_state: The previous state of the object.
    :raises NotImplementedError: If no implementation is registered for the
        type of `new_state`.
    """
    # singledispatch chooses the implementation from the first argument, so
    # that is the type that is actually missing an implementation.
    err_msg = f'get_added not implemented for type {type(new_state).__name__}'
    raise NotImplementedError(err_msg)
@singledispatch
def get_updated(new_state, old_state):
    """
    Calculate what was updated going from `old_state` to `new_state`.

    This is the generic fallback. Type-specific implementations are attached
    with `get_updated.register` and are selected by the type of `new_state`
    (the first positional argument).

    :param new_state: The current state of the object.
    :param old_state: The previous state of the object.
    :raises NotImplementedError: If no implementation is registered for the
        type of `new_state`.
    """
    # singledispatch chooses the implementation from the first argument, so
    # that is the type that is actually missing an implementation.
    err_msg = f'get_updated not implemented for type {type(new_state).__name__}'
    raise NotImplementedError(err_msg)
@singledispatch
def get_moved(new_state, old_state):
    """
    Calculate what was moved going from `old_state` to `new_state`.

    This is the generic fallback. Type-specific implementations are attached
    with `get_moved.register` and are selected by the type of `new_state`
    (the first positional argument).

    :param new_state: The current state of the object.
    :param old_state: The previous state of the object.
    :raises NotImplementedError: If no implementation is registered for the
        type of `new_state`.
    """
    # The message names this function (it previously said `get_updated`) and
    # the type singledispatch actually dispatched on, i.e. the first argument.
    err_msg = f'get_moved not implemented for type {type(new_state).__name__}'
    raise NotImplementedError(err_msg)
"""
Implementations of generic functions.
"""

# Result record for list diffs: one field per kind of change.
ListDiff = collections.namedtuple('ListDiff', ['added', 'removed', 'moved'])
@get_diff.register(list)
def get_diff_list(new_state, old_state):
    """
    Work out what was added to, removed from, and moved within a list
    between `old_state` and `new_state`.

    The result bundles the outputs of `get_added`, `get_removed`, and
    `get_moved`; see those functions for more details on each field.

    Note that items are intended to be compared by identity rather than
    equality (ie `is` rather than `==`).

    :return: A namedtuple describing what was added, removed, and moved
        between the two lists.
    :rtype: ListDiff(set, set, set)
    """
    return ListDiff(
        added=get_added(new_state, old_state),
        removed=get_removed(new_state, old_state),
        moved=get_moved(new_state, old_state),
    )
# Result record for set diffs: members gained and members lost.
SetDiff = collections.namedtuple('SetDiff', ['added', 'removed'])
@get_diff.register(set)
def get_diff_set(new_state, old_state):
    """
    Work out which members were added to and removed from a set between
    `old_state` and `new_state`.

    :return: A namedtuple holding the results of `get_added` and
        `get_removed` for the two states.
    :rtype: SetDiff(set, set)
    """
    return SetDiff(
        added=get_added(new_state, old_state),
        removed=get_removed(new_state, old_state),
    )
@get_removed.register(set)
def get_removed_set(new_state, old_state):
    """
    Return the members of `old_state` that are no longer in `new_state`.

    Both states are copied into plain sets before being compared.

    :rtype: set
    """
    previous = set(old_state)
    current = set(new_state)
    return previous.difference(current)
@get_removed.register(list)
def get_removed_list(new_state, old_state):
    """
    Find the items that were dropped from a list.

    :return: A set of tuples, each describing an item that was removed and
        its index in `old_state`.
    :rtype: set((object, int))
    """
    candidates = set()

    # Indexes present in both lists whose contents differ (compared with !=).
    for position, (old_item, new_item) in enumerate(zip(old_state, new_state)):
        if old_item != new_item:
            candidates.add(_HashableTuple((old_item, position)))

    # Anything beyond the end of `new_state` has lost its slot entirely.
    tail_start = len(new_state)
    for position, old_item in enumerate(old_state[tail_start:], tail_start):
        candidates.add(_HashableTuple((old_item, position)))

    # Items that merely changed position are still in the list, so their old
    # (item, index) entries must not be reported as removals. `list.index`
    # yields the first matching position in `old_state`.
    relocated = [
        _HashableTuple((item, old_state.index(item)))
        for item, _new_idx in get_moved(new_state, old_state)
    ]
    return candidates.difference(relocated)
@get_added.register(list)
def get_added_list(new_state, old_state):
    """
    Find the items that were newly inserted into a list.

    :return: A set of tuples, each describing an item that was added and
        its index in `new_state`.
    :rtype: set((object, int))
    """
    candidates = set()

    # Indexes present in both lists whose contents differ (compared with !=).
    for position, (old_item, new_item) in enumerate(zip(old_state, new_state)):
        if old_item != new_item:
            candidates.add(_HashableTuple((new_item, position)))

    # Anything beyond the end of `old_state` occupies a brand new slot.
    tail_start = len(old_state)
    for position, new_item in enumerate(new_state[tail_start:], tail_start):
        candidates.add(_HashableTuple((new_item, position)))

    # Items that merely changed position were already in the list, so their
    # new (item, index) entries must not be reported as additions.
    relocated = [
        _HashableTuple((item, new_idx))
        for item, new_idx in get_moved(new_state, old_state)
    ]
    return candidates.difference(relocated)
@get_added.register(set)
def get_added_set(new_state, old_state):
    """
    Calculate what was added between two states of a set.

    :param new_state: The current members.
    :param old_state: The previous members; any iterable of hashable items.
    :return: The members of `new_state` that are not in `old_state`.
    :rtype: set
    """
    # Coerce both sides to plain sets, as get_removed_set does, so that a
    # non-set `old_state` (e.g. a list) does not raise a TypeError from `-`.
    return set(new_state) - set(old_state)
# Result record for dict diffs: one field per kind of change.
DictDiff = collections.namedtuple('DictDiff', ['added', 'removed', 'updated'])
@get_diff.register(dict)
def get_diff_dict(new_state, old_state):
    """
    Work out which dictionary items were added, removed, and updated
    between `old_state` and `new_state`.

    :return: A namedtuple describing what was added, removed, and updated
        between the two dicts. See `get_added`, `get_removed`, and
        `get_updated` for more details.
    :rtype: DictDiff(dict, dict, dict)
    """
    return DictDiff(
        added=get_added(new_state, old_state),
        removed=get_removed(new_state, old_state),
        updated=get_updated(new_state, old_state),
    )
@get_added.register(dict)
def get_added_dict(new_state, old_state):
    """
    Collect the entries whose keys appear only in `new_state`.

    :return: A dictionary with items in `new_state` but not in `old_state`.
    :rtype: dict
    """
    added = {}
    for key in new_state:
        if key in old_state:
            continue
        added[key] = new_state[key]
    return added
@get_removed.register(dict)
def get_removed_dict(new_state, old_state):
    """
    Collect the entries whose keys appear only in `old_state`.

    :return: A dictionary with items in `old_state` but not in `new_state`.
    :rtype: dict
    """
    removed = {}
    for key, old_value in old_state.items():
        if key in new_state:
            continue
        removed[key] = old_value
    return removed
@get_updated.register(dict)
def get_updated_dict(new_state, old_state):
    """
    Collect the entries whose key is in both states but whose value differs.

    :return: A dictionary with values that have changed from `old_state` to
        `new_state`. The values in the returned dictionary will be those
        from `new_state`.
    :rtype: dict
    """
    updated = {}
    for key, old_value in old_state.items():
        # Keys missing from `new_state` are removals, not updates.
        if key not in new_state:
            continue
        new_value = new_state[key]
        if new_value != old_value:
            updated[key] = new_value
    return updated
@get_moved.register(list)
def get_moved_list(new_state, old_state):
    """
    Find the items that exist in both lists but at a different index.

    :return: A set of tuples, each describing an item that was moved and
        its index in `new_state`.
    :rtype: set((object, int))
    """
    # Map each item of `new_state` to its index there.
    # NOTE(review): IdDict/IdSet presumably key on object identity rather
    # than equality - confirm against schrodinger.utils.scollections.
    new_positions = IdDict.fromIterable(
        (obj, pos) for pos, obj in enumerate(new_state))

    # An item sitting at the same index in both lists has not moved, so it
    # is dropped from the lookup.
    for pos, obj in enumerate(old_state):
        if new_positions.get(obj) == pos:
            new_positions.pop(obj)

    # Whatever remains in the lookup and also occurs in `old_state` changed
    # position.
    relocated = IdSet(old_state).intersection(IdSet(new_positions.keys()))
    moved = set()
    for obj in relocated:
        moved.add(_HashableTuple((obj, new_positions[obj])))
    return moved
class _HashableTuple(tuple): def __hash__(self): """ This will return an identical hash value to a regular tuple if this tuple is already made up of hashable items. If any items are not hashable, we use the id of the value instead. """ hashes = [] for idx, item in enumerate(self): try: hashes.append(hash(item)) except TypeError: # If the item is unhashable, settle for its id. hashes.append(id(item)) return hash(tuple(hashes)) def __eq__(self, other): if len(self) != len(other): return False else: for s_item, o_item in zip(self, other): if not isinstance(s_item, abc.Hashable): if s_item is not o_item: return False else: if s_item != o_item: return False return True