"""
Need a diff in a jiffy? Use diffy!
"""
import collections
from collections import abc
from functools import singledispatch
from schrodinger.utils.scollections import IdDict
from schrodinger.utils.scollections import IdSet
"""
Generic Functions
"""
@singledispatch
def get_diff(new_state, old_state):
    """
    Given two states of an object, calculate what changed between them.

    This is the generic fallback; type-specific implementations are attached
    with `get_diff.register`. Reaching this body means no implementation was
    registered for the dispatched type.

    :param new_state: The current state of the object.
    :param old_state: The previous state of the object.
    :raises NotImplementedError: Always. The message names the type of
        `old_state`.
    """
    raise NotImplementedError(
        f'get_diff not implemented for type {type(old_state).__name__}')
@singledispatch
def get_removed(new_state, old_state):
    """
    Given two states of an object, calculate what was removed.

    Generic fallback used when no type-specific implementation has been
    registered with `get_removed.register`.

    :param new_state: The current state of the object.
    :param old_state: The previous state of the object.
    :raises NotImplementedError: Always. The message names the type of
        `old_state`.
    """
    raise NotImplementedError(
        f'get_removed not implemented for type {type(old_state).__name__}')
@singledispatch
def get_added(new_state, old_state):
    """
    Given two states of an object, calculate what was added.

    Generic fallback used when no type-specific implementation has been
    registered with `get_added.register`.

    :param new_state: The current state of the object.
    :param old_state: The previous state of the object.
    :raises NotImplementedError: Always. The message names the type of
        `old_state`.
    """
    raise NotImplementedError(
        f'get_added not implemented for type {type(old_state).__name__}')
@singledispatch
def get_updated(new_state, old_state):
    """
    Given two states of an object, calculate what was updated.

    Generic fallback used when no type-specific implementation has been
    registered with `get_updated.register`.

    :param new_state: The current state of the object.
    :param old_state: The previous state of the object.
    :raises NotImplementedError: Always. The message names the type of
        `old_state`.
    """
    raise NotImplementedError(
        f'get_updated not implemented for type {type(old_state).__name__}')
@singledispatch
def get_moved(new_state, old_state):
    """
    Given two states of an object, calculate what was moved.

    Generic fallback used when no type-specific implementation has been
    registered with `get_moved.register`.

    :param new_state: The current state of the object.
    :param old_state: The previous state of the object.
    :raises NotImplementedError: Always. The message names the type of
        `old_state`.
    """
    # The message must name this function; it previously said `get_updated`,
    # a copy-paste slip that pointed callers at the wrong generic.
    err_msg = f'get_moved not implemented for type {type(old_state).__name__}'
    raise NotImplementedError(err_msg)
"""
Implementations of generic functions.
"""
ListDiff = collections.namedtuple('ListDiff', 'added removed moved')


@get_diff.register(list)
def get_diff_list(new_state, old_state):
    """
    Calculate what was added, removed, and moved between two states of a list.

    Each field is produced by the matching generic function (`get_added`,
    `get_removed`, `get_moved`), called with the same two arguments. Note
    that items are compared by identity not equality (ie `is` rather than
    `==`).

    :param new_state: The current list.
    :param old_state: The previous list.
    :return: A namedtuple describing what was added, removed, and moved
        between two lists. See `get_added`, `get_removed`, and `get_moved`
        for more details.
    :rtype: ListDiff(set, set, set)
    """
    return ListDiff(
        added=get_added(new_state, old_state),
        removed=get_removed(new_state, old_state),
        moved=get_moved(new_state, old_state),
    )
SetDiff = collections.namedtuple('SetDiff', 'added removed')


@get_diff.register(set)
def get_diff_set(new_state, old_state):
    """
    Calculate what was added and removed between two states of a set.

    Each field is produced by the matching generic function (`get_added`,
    `get_removed`), called with the same two arguments.

    :param new_state: The current set.
    :param old_state: The previous set.
    :return: A namedtuple describing what was added and removed.
    :rtype: SetDiff(set, set)
    """
    return SetDiff(
        added=get_added(new_state, old_state),
        removed=get_removed(new_state, old_state),
    )
@get_removed.register(set)
def get_removed_set(new_state, old_state):
    """
    Calculate what was removed between two states of a set.

    Both arguments are copied into fresh `set` objects before the
    difference is taken.

    :param new_state: The current set.
    :param old_state: The previous set.
    :return: The items present in `old_state` but absent from `new_state`.
    :rtype: set
    """
    previous = set(old_state)
    current = set(new_state)
    return previous - current
@get_removed.register(list)
def get_removed_list(new_state, old_state):
    """
    Calculate what was removed between two states of a list.

    An item of `old_state` is a removal candidate when the item at the same
    index of `new_state` compares unequal to it (`!=`), or when its index
    lies past the end of `new_state`. Candidates that `get_moved` reports as
    moved are then dropped from the result.

    :param new_state: The current list.
    :param old_state: The previous list.
    :return: A set of tuples, each describing an item that was removed and
        its index in `old_state`.
    :rtype: set((object, int))
    """
    new_len = len(new_state)
    candidates = set()
    for pos, (old_item, new_item) in enumerate(zip(old_state, new_state)):
        if old_item != new_item:
            candidates.add(_HashableTuple((old_item, pos)))
    # Positions beyond the end of new_state have no counterpart to compare
    # against, so every item there is a candidate.
    for pos, old_item in enumerate(old_state[new_len:], new_len):
        candidates.add(_HashableTuple((old_item, pos)))
    relocated = get_moved(new_state, old_state)
    # get_moved pairs each item with its index in new_state; the candidates
    # are keyed by index in old_state, so look the old index up again
    # (list.index returns the first position that compares equal).
    moved_from = [
        _HashableTuple((item, old_state.index(item))) for item, _ in relocated
    ]
    return candidates.difference(moved_from)
@get_added.register(list)
def get_added_list(new_state, old_state):
    """
    Calculate what was added between two states of a list.

    An item of `new_state` is an addition candidate when the item at the
    same index of `old_state` compares unequal to it (`!=`), or when its
    index lies past the end of `old_state`. Candidates that `get_moved`
    reports as moved are then dropped from the result.

    :param new_state: The current list.
    :param old_state: The previous list.
    :return: A set of tuples, each describing an item that was added and
        its index in `new_state`.
    :rtype: set((object, int))
    """
    old_len = len(old_state)
    candidates = set()
    for pos, (old_item, new_item) in enumerate(zip(old_state, new_state)):
        if old_item != new_item:
            candidates.add(_HashableTuple((new_item, pos)))
    # Positions beyond the end of old_state have no counterpart to compare
    # against, so every item there is a candidate.
    for pos, new_item in enumerate(new_state[old_len:], old_len):
        candidates.add(_HashableTuple((new_item, pos)))
    relocated = get_moved(new_state, old_state)
    # get_moved already pairs each item with its index in new_state, which
    # is the same keying the candidates use.
    moved_to = [_HashableTuple((item, pos)) for item, pos in relocated]
    return candidates.difference(moved_to)
@get_added.register(set)
def get_added_set(new_state, old_state):
    """
    Calculate what was added between two states of a set.

    Both arguments are copied into fresh `set` objects before the
    difference is taken, exactly as `get_removed_set` does. Dispatch only
    inspects `new_state`, so `old_state` may be any iterable of hashable
    items; the `-` operator alone would raise `TypeError` for a non-set
    operand.

    :param new_state: The current set.
    :param old_state: The previous set (or any iterable of hashable items).
    :return: The items present in `new_state` but absent from `old_state`.
    :rtype: set
    """
    old_state, new_state = set(old_state), set(new_state)
    return new_state - old_state
DictDiff = collections.namedtuple('DictDiff', 'added removed updated')


@get_diff.register(dict)
def get_diff_dict(new_state, old_state):
    """
    Return dictionary items that have been added, removed, and updated.

    Each field is produced by the matching generic function (`get_added`,
    `get_removed`, `get_updated`), called with the same two arguments.

    :param new_state: The current dictionary.
    :param old_state: The previous dictionary.
    :return: A namedtuple describing what was added, removed, and updated
        between two dicts. See `get_added`, `get_removed`, and `get_updated`
        for more details.
    :rtype: DictDiff(dict, dict, dict)
    """
    return DictDiff(
        added=get_added(new_state, old_state),
        removed=get_removed(new_state, old_state),
        updated=get_updated(new_state, old_state),
    )
@get_added.register(dict)
def get_added_dict(new_state, old_state):
    """
    Calculate which keys were added between two states of a dict.

    :param new_state: The current dictionary.
    :param old_state: The previous dictionary.
    :return: A dictionary with items in `new_state` but not in `old_state`.
    :rtype: dict
    """
    added = {}
    for key in new_state:
        if key in old_state:
            continue
        added[key] = new_state[key]
    return added
@get_removed.register(dict)
def get_removed_dict(new_state, old_state):
    """
    Calculate which keys were removed between two states of a dict.

    :param new_state: The current dictionary.
    :param old_state: The previous dictionary.
    :return: A dictionary with items in `old_state` but not in `new_state`.
    :rtype: dict
    """
    removed = {}
    for key, value in old_state.items():
        if key in new_state:
            continue
        removed[key] = value
    return removed
@get_updated.register(dict)
def get_updated_dict(new_state, old_state):
    """
    Calculate which values changed between two states of a dict.

    Only keys present in both dictionaries are considered; a value counts as
    changed when the new value compares unequal (`!=`) to the old one.

    :param new_state: The current dictionary.
    :param old_state: The previous dictionary.
    :return: A dictionary with values that have changed from `old_state` to
        `new_state`. The values in the returned dictionary will be those
        from `new_state`.
    :rtype: dict
    """
    updated = {}
    for key, old_value in old_state.items():
        if key not in new_state:
            continue
        if new_state[key] != old_value:
            updated[key] = new_state[key]
    return updated
@get_moved.register(list)
def get_moved_list(new_state, old_state):
    """
    Calculate which items changed position between two states of a list.

    An item is reported when it occurs in both lists and the index recorded
    for it in `new_state` differs from its index in `old_state`.

    NOTE(review): `IdDict` / `IdSet` are presumably keyed by object identity
    rather than equality -- confirm against `schrodinger.utils.scollections`.

    :param new_state: The current list.
    :param old_state: The previous list.
    :return: A set of tuples, each describing an item that was moved and
        its index in `new_state`.
    :rtype: set((object, int))
    """
    new_positions = IdDict.fromIterable(
        (element, position) for position, element in enumerate(new_state))
    # Forget every item that still sits at the index it had in old_state.
    for position, element in enumerate(old_state):
        if new_positions.get(element) == position:
            new_positions.pop(element)
    # What is left in the mapping and also occurs in old_state has moved.
    relocated = IdSet(old_state).intersection(IdSet(new_positions.keys()))
    moved = set()
    for element in relocated:
        moved.add(_HashableTuple((element, new_positions[element])))
    return moved
class _HashableTuple(tuple):
def __hash__(self):
"""
This will return an identical hash value to a regular tuple if this
tuple is already made up of hashable items. If any items are not hashable,
we use the id of the value instead.
"""
hashes = []
for idx, item in enumerate(self):
try:
hashes.append(hash(item))
except TypeError:
# If the item is unhashable, settle for its id.
hashes.append(id(item))
return hash(tuple(hashes))
def __eq__(self, other):
if len(self) != len(other):
return False
else:
for s_item, o_item in zip(self, other):
if not isinstance(s_item, abc.Hashable):
if s_item is not o_item:
return False
else:
if s_item != o_item:
return False
return True