Source code for schrodinger.utils.scollections

import copy
from collections import abc

import more_itertools


def split_list(input_list, num_chunks):
    """
    Split a list into `num_chunks` chunks whose lengths differ by at most one.

    Note: function is similar to numpy.array_split -- the first
    ``len(input_list) % num_chunks`` chunks receive one extra item, and when
    there are fewer items than chunks the trailing chunks are empty.

    :param input_list: The list to split
    :type input_list: list

    :param num_chunks: The desired number of chunks
    :type num_chunks: int

    :return: `num_chunks` slices of `input_list`. As a special case, an empty
        `input_list` is returned as the single chunk ``[input_list]``.
    :rtype: list[list]

    :raise ValueError: If `input_list` is not empty and `num_chunks` < 1.
    """
    if not input_list:
        # Preserved legacy behavior: an empty input is one (empty) chunk.
        return [input_list]
    if num_chunks < 1:
        raise ValueError(f'num_chunks must be at least 1, got {num_chunks}')

    base_size, num_larger = divmod(len(input_list), num_chunks)
    chunks = []
    start = 0
    for chunk_idx in range(num_chunks):
        # The first `num_larger` chunks absorb the remainder, one item each.
        stop = start + base_size + (1 if chunk_idx < num_larger else 0)
        # Slicing yields a fresh object per chunk, so empty chunks are never
        # aliases of each other.
        chunks.append(input_list[start:stop])
        start = stop
    return chunks
class IdSet(abc.MutableSet, set):
    """
    An id set is a set that uses the id of an object as the key instead of the
    hash. This means that two objects that compare equal but are different
    instances will be stored separately since id(obj1) != id(obj2).

    The ids of the members are kept in the underlying builtin `set` storage,
    while `_id_to_obj_map` maps each id back to its object (holding a
    reference to it) so the members themselves can be iterated.

    Set operations accept other IdSets or arbitrary iterables of objects, and
    membership is always decided by identity. Builtin (non-IdSet) `set`
    instances are rejected with a ValueError.

    NOTE: Using this set with python's builtin immutable datatypes is
    /strongly/ discouraged (e.g. str and int are not guaranteed to have
    different ids for separate instances)
    """

    def __init__(self, initial_set=None):
        """
        :param initial_set: Optional IdSet or iterable of objects to populate
            the set with. A builtin `set` raises ValueError.
        """
        self._id_to_obj_map = {}
        if initial_set is not None:
            self.update(initial_set)

    def __contains__(self, obj):
        return set.__contains__(self, id(obj))

    def __iter__(self):
        # Yield the member objects themselves, not their ids.
        return iter(self._id_to_obj_map.values())

    def __len__(self):
        return set.__len__(self)

    def __copy__(self):
        return IdSet(self)

    def __deepcopy__(self, memo=None):
        # copy.deepcopy calls __deepcopy__(memo); accept the argument so the
        # intended error is raised instead of a signature-mismatch TypeError.
        raise TypeError("Deepcopy is incompatible with IdSet")

    def copy(self):
        # NB default copy method does not call __copy__
        return copy.copy(self)

    @classmethod
    def fromIterable(cls, iterable):
        """
        Build a new set of this class containing every object in `iterable`.
        """
        new_set = cls()
        for item in iterable:
            new_set.add(item)
        return new_set

    @staticmethod
    def _asIdSet(other):
        """
        Return `other` as an IdSet, wrapping (and thereby materializing) any
        other iterable so it can be queried by identity more than once.
        """
        return other if isinstance(other, IdSet) else IdSet(other)

    def isdisjoint(self, other):
        """
        :return: True if no object in `other` is (by identity) in this set.
        """
        self._checkOtherSets(other)
        return not any(obj in self for obj in other)

    def issubset(self, other):
        """
        :return: True if every member of this set is (by identity) in `other`.
        """
        self._checkOtherSets(other)
        other = self._asIdSet(other)
        return all(obj in other for obj in self)

    def issuperset(self, other):
        """
        :return: True if every object in `other` is (by identity) in this set.
        """
        self._checkOtherSets(other)
        return all(obj in self for obj in other)

    def union(self, *other_sets):
        """
        :return: A new set with the members of this set and of `other_sets`.
        """
        self._checkOtherSets(*other_sets)
        new_set = self.fromIterable(self)
        new_set.update(*other_sets)
        return new_set

    def intersection(self, *other_sets):
        """
        :return: A new set with the members common to this set and all of
            `other_sets`.
        """
        self._checkOtherSets(*other_sets)
        others = [self._asIdSet(o_set) for o_set in other_sets]
        return self.fromIterable(
            obj for obj in self if all(obj in o_set for o_set in others))

    def difference(self, *other_sets):
        """
        :return: A new set with the members of this set that are in none of
            `other_sets`.
        """
        self._checkOtherSets(*other_sets)
        others = [self._asIdSet(o_set) for o_set in other_sets]
        return self.fromIterable(
            obj for obj in self if not any(obj in o_set for o_set in others))

    def symmetric_difference(self, *other_sets):
        """
        :return: A new set equal to this set symmetric-differenced with each
            of `other_sets` in turn.
        """
        self._checkOtherSets(*other_sets)
        new_set = self.fromIterable(self)
        new_set.symmetric_difference_update(*other_sets)
        return new_set

    def update(self, *other_sets):
        """
        Add every object from each of `other_sets` to this set.
        """
        self._checkOtherSets(*other_sets)
        for o_set in other_sets:
            for itm in o_set:
                self.add(itm)

    def intersection_update(self, *other_sets):
        """
        Keep only the members that are also in every one of `other_sets`.
        """
        self._checkOtherSets(*other_sets)
        others = [self._asIdSet(o_set) for o_set in other_sets]
        # Collect first: discarding while iterating would mutate the map
        # that is being iterated.
        to_remove = [
            obj for obj in self if not all(obj in o_set for o_set in others)
        ]
        for obj in to_remove:
            self.discard(obj)

    def difference_update(self, *other_sets):
        """
        Remove every member that is in any of `other_sets`.
        """
        self._checkOtherSets(*other_sets)
        others = [self._asIdSet(o_set) for o_set in other_sets]
        to_remove = [
            obj for obj in self if any(obj in o_set for o_set in others)
        ]
        for obj in to_remove:
            self.discard(obj)

    def symmetric_difference_update(self, *other_sets):
        """
        Toggle membership of every object in each of `other_sets` in turn.
        """
        self._checkOtherSets(*other_sets)
        for o_set in other_sets:
            # Snapshot into a fresh IdSet: dedupes by identity and makes it
            # safe when `o_set` is this very set.
            for obj in IdSet(o_set):
                if obj in self:
                    self.discard(obj)
                else:
                    self.add(obj)

    def add(self, obj):
        obj_id = id(obj)
        self._id_to_obj_map[obj_id] = obj
        set.add(self, obj_id)

    def discard(self, obj):
        obj_id = id(obj)
        if obj_id in self._id_to_obj_map:
            del self._id_to_obj_map[obj_id]
            set.discard(self, obj_id)

    def clear(self):
        """
        Remove all members (directly, instead of MutableSet's pop loop).
        """
        self._id_to_obj_map.clear()
        set.clear(self)

    def _checkOtherSets(self, *other_sets):
        """
        :raise ValueError: If any argument is a builtin (non-IdSet) set.
        """
        any_set_builtin = any(
            isinstance(o_set, set) and not isinstance(o_set, IdSet)
            for o_set in other_sets)
        if any_set_builtin:
            raise ValueError('Set operations only supported with other IdSets')
class IdItemsView(abc.ItemsView):
    """
    Items view yielding ``(obj, value)`` pairs, where ``obj`` is the original
    key object (looked up through `id_map`) rather than its id.
    """

    def __init__(self, id_dict, id_map):
        """
        :param id_dict: The mapping whose values are looked up by key object.
        :param id_map: Mapping of ``id(obj) -> obj`` for every key of
            `id_dict`.
        """
        # Initialise the MappingView base so inherited helpers such as
        # __repr__ (which reads ``self._mapping``) work.
        super().__init__(id_dict)
        self.id_dict = id_dict
        self.id_map = id_map

    def __contains__(self, item):
        k, v = item
        return k in self.id_dict and self.id_dict[k] == v

    def __iter__(self):
        # A fresh generator per call, so nested or zipped iterations over the
        # same view do not share iteration state.
        for obj in self.id_map.values():
            yield obj, self.id_dict[obj]

    def __len__(self):
        return len(self.id_dict)
class IdDict(abc.MutableMapping, dict):
    """
    An id dict is a dictionary that uses the id of an object as the key
    instead of the hash. This means that two objects that compare equal but
    are different instances will be stored separately since
    id(obj1) != id(obj2).

    The underlying builtin `dict` storage maps ``id(key) -> value``, and
    `_id_to_obj_map` maps ``id(key) -> key`` (holding a reference to each key)
    so the key objects can be iterated.

    NOTE: Using this dict with python's builtin immutable datatypes is
    /strongly/ discouraged (e.g. str and int are not guaranteed to have
    different ids for separate instances)
    """

    def __init__(self, initial_dict=None):
        """
        :param initial_dict: Optional IdDict whose entries are copied in.
        :type initial_dict: IdDict or None

        :raise ValueError: If `initial_dict` is given but is not an IdDict.
        """
        self._id_to_obj_map = {}
        if initial_dict is not None:
            if not isinstance(initial_dict, IdDict):
                err_msg = 'IdDict can only be initialized with another IdDict'
                raise ValueError(err_msg)
            self.update(initial_dict)

    def __getitem__(self, obj):
        obj_id = id(obj)
        try:
            return dict.__getitem__(self, obj_id)
        except KeyError:
            # Report the key object, not its meaningless numeric id.
            raise KeyError(str(obj)) from None

    def __setitem__(self, obj, value):
        obj_id = id(obj)
        self._id_to_obj_map[obj_id] = obj
        dict.__setitem__(self, obj_id, value)

    def __delitem__(self, obj):
        obj_id = id(obj)
        try:
            del self._id_to_obj_map[obj_id]
        except KeyError:
            # Same error shape as __getitem__.
            raise KeyError(str(obj)) from None
        dict.__delitem__(self, obj_id)

    def setdefault(self, key, default=None):
        """
        Return the value for `key`, first storing `default` if `key` is absent.

        A stored value of None is returned as-is and is not overwritten.
        """
        if key not in self:
            self[key] = default
        return self[key]

    def __contains__(self, obj):
        return dict.__contains__(self, id(obj))

    def __len__(self):
        return dict.__len__(self)

    def items(self):
        return IdItemsView(self, self._id_to_obj_map)

    def keys(self):
        """
        :return: A view of the key objects. Membership tests on it delegate to
            `__contains__`, so they are identity-based like ``obj in self``.
        """
        return abc.KeysView(self)

    def __eq__(self, other):
        if not isinstance(other, IdDict):
            return NotImplemented
        # Test membership before indexing `other` so that comparing against a
        # DefaultIdDict never inserts default entries into it.
        return len(self) == len(other) and all(
            key in other and other[key] == self[key] for key in self)

    def __iter__(self):
        return iter(self._id_to_obj_map.values())

    def __repr__(self):
        item_reprs = []
        for k, v in self.items():
            item_reprs.append(f'{repr(k)}: {repr(v)}')
        return 'IdDict({' + ', '.join(item_reprs) + '})'

    def has_key(self, obj):
        return self.__contains__(obj)

    def update(self, other_dict):
        """
        :raise ValueError: If `other_dict` is not an IdDict.
        """
        if not isinstance(other_dict, IdDict):
            raise ValueError('Update is only supported with other IdDicts')
        self.updateFromIterable(other_dict.items())

    def updateFromIterable(self, iterable):
        """
        Store every ``(key, value)`` pair from `iterable`.
        """
        for k, v in iterable:
            self[k] = v

    @classmethod
    def fromIterable(cls, iterable):
        """
        Build a new dict of this class from ``(key, value)`` pairs.
        """
        id_dict = cls()
        id_dict.updateFromIterable(iterable)
        return id_dict

    def clear(self):
        self._id_to_obj_map.clear()
        dict.clear(self)

    def copy(self):
        return IdDict(self)
class DefaultIdDict(IdDict):
    """
    A dict that is both an id dict and a defaultdict.

    Subscripting a missing key (``d[obj]``) stores and returns
    ``default_factory()``. As with `collections.defaultdict`, `get`, `pop`
    and membership tests never create entries.
    """

    # Private sentinel meaning "no default supplied" for `pop`.
    _MISSING = object()

    def __init__(self, default_factory):
        """
        :param default_factory: Zero-argument callable producing the value
            stored for a missing key on item access.
        :type default_factory: callable
        """
        super().__init__()
        self._default_factory = default_factory

    def __getitem__(self, obj):
        obj_id = id(obj)
        try:
            return dict.__getitem__(self, obj_id)
        except KeyError:
            default_value = self._default_factory()
            self[obj] = default_value
            return default_value

    def get(self, key, default=None):
        """
        Return the value for `key`, or `default` if absent, without inserting.
        """
        # The inherited Mapping.get goes through self[key], which would
        # insert a factory-made entry for a missing key.
        if key in self:
            return dict.__getitem__(self, id(key))
        return default

    def pop(self, key, default=_MISSING):
        """
        Remove `key` and return its value.

        :raise KeyError: If `key` is absent and no `default` was given.
        """
        # The inherited MutableMapping.pop goes through self[key], which for a
        # missing key would create (then delete) a factory-made entry and
        # return it instead of honoring `default` or raising KeyError.
        if key in self:
            value = dict.__getitem__(self, id(key))
            del self[key]
            return value
        if default is self._MISSING:
            raise KeyError(str(key))
        return default

    @classmethod
    def fromIterable(cls, iterable):
        raise NotImplementedError()

    def setdefault(self, key, default):
        raise NotImplementedError()
class DefaultFactoryDictMixin:
    """
    Mixin for `dict` subclasses that builds missing values with a factory,
    much like `defaultdict` -- except that the factory receives the missing
    key as its argument(s) instead of being called with no arguments.

    .. NOTE::
        Any callable works as the factory, including a class: its
        constructor is then called with the key(s).

    .. WARNING::
        A tuple key is unpacked into positional arguments, so ``d[1, 2]``
        calls ``factory_func(1, 2)``. Consequently, a factory that expects a
        single tuple argument cannot be used with this mixin.
    """

    def __init__(self, factory_func, *args, **kwargs):
        """
        :param factory_func: Callable used to create the value for a missing
            key; it is called with the key (tuple keys are unpacked).
        :type factory_func: callable

        All remaining positional and keyword arguments are forwarded to the
        next ``__init__`` in the MRO (e.g. `dict`).
        """
        self._factory_func = factory_func
        super().__init__(*args, **kwargs)

    def __missing__(self, key):
        # Tuple keys are spread into separate arguments; any other key is
        # passed as the sole argument.
        factory_args = key if isinstance(key, tuple) else (key,)
        value = self[key] = self._factory_func(*factory_args)
        return value
class DefaultFactoryDict(DefaultFactoryDictMixin, dict):
    """
    Plain `dict` subclass that applies `DefaultFactoryDictMixin`.

    The mixin is kept separate so that other `dict` subclasses can reuse it.

    Example usage::

        stringified_objs = DefaultFactoryDict(str)
        assert 1 not in stringified_objs
        stringified_objs[1]  # -> '1' (the entry is now stored)
    """