# Source code for schrodinger.utils.funcgroups

"""
This module provides a mechanism for marking the methods on a class with a
decorator. These methods can then be retrieved as an ordered list from an
instance of that class. Each decorator instance creates a separate group.

To use, mix in the FuncGroupMixin to any class, and decorate the desired
methods with an instance of FuncGroupMarker.

A generic decorator, funcgroup, is provided for convenience, but it is
recommended that a separate decorator be created for each group.

Example::

    my_startup_funcs = FuncGroupMarker()

    class Foo(FuncGroupMixin):
        def __init__(self):
            super().__init__()
            for func in self.getFuncGroup(my_startup_funcs):
                func()

        @my_startup_funcs(order=1)
        def initVariables(self):
            ...

        @my_startup_funcs(order=2)
        def setupWorkspace(self):
            ...

"""
import collections


class FuncGroupMarker:
    """
    This decorator marks a method on a class as belonging to a group.

    Each FuncGroupMarker instance defines its own group. The marker can be
    applied directly (``@marker``, which uses order 0) or called with an
    order first (``@marker(order=2)``).
    """

    def __init__(self, label=None):
        """
        :param label: An optional human-readable identifier to use in the repr
        :type label: str
        """
        self.label = label

    def __call__(self, func=None, order=None):
        """
        Mark ``func`` as belonging to this group.

        :param func: the function to mark. If None, the marker is being called
            only to specify an order, and a decorator is returned instead.
        :param order: a numerical value that determines in what order the
            methods are returned when retrieved. Defaults to 0.
        :type order: float
        :return: ``func`` itself with the marker attributes set, or a
            decorator that does so when ``func`` is None
        """
        if func is None:
            # Called to specify the order, e.g.
            #   @funcgroup(order=2)
            #   def foo(self): ...
            # so hand back a decorator that remembers the requested order.
            return lambda func: self(func, order=order)
        # Used directly as a decorator, e.g.
        #   @funcgroup
        #   def foo(self): ...
        if order is None:
            order = 0
        func._marked_method_group = self
        func._marked_method_order = order
        return func

    def __repr__(self):
        if self.label:
            return f'<Func Group {self.label}>'
        return f'<Func Group {id(self)}>'
# Generic, shared group marker; also the default group used by
# FuncGroupMixin when no explicit group is given.
funcgroup = FuncGroupMarker(label='funcgroup')
def get_marked_func_order(func):
    """
    Look up the order assigned to ``func`` by a FuncGroupMarker decorator.

    :param func: the function to get the order from
    :return: the order set by the decorator, or None if ``func`` was not
        decorated
    :rtype: float or None
    """
    return getattr(func, '_marked_method_order', None)
class FuncData:
    """
    Simple record pairing a function name with its sort order and,
    optionally, the function object itself.
    """

    def __init__(self, name, order, func=None):
        """
        :param name: the attribute name of the function
        :type name: str
        :param order: the sorting order of the function within its group
        :param func: the function itself, if already available (for example a
            bound method)
        """
        self.name, self.order, self.func = name, order, func

    def __repr__(self):
        return '{}:{}'.format(self.name, self.order)
class FuncGroupMixin:
    """
    Use this mixin on an object so that methods decorated with a
    FuncGroupMarker can be retrieved on an instance of this object.

    Marked methods are recorded by name whenever a subclass is created. The
    subclass's whole MRO is considered, so marked methods inherited through
    multiple inheritance (or from plain base classes) are found as well.
    """
    # group marker -> {method name: FuncData}; each subclass gets its own
    # freshly built mapping from _collectMarkedMethodData.
    _marked_method_groups = collections.defaultdict(dict)
    # Group used by the methods below when no group is passed explicitly.
    _default_funcgroup = funcgroup

    def __init_subclass__(cls, *args, **kwargs):
        super().__init_subclass__(*args, **kwargs)
        cls._collectMarkedMethodData()

    def __init__(self, *args, **kwargs):
        # Make sure the per-instance storage exists, without discarding funcs
        # a subclass may already have added before calling super().__init__().
        self._getExtraFuncsByGroup()
        super().__init__(*args, **kwargs)

    def _getExtraFuncsByGroup(self):
        """
        Return the per-instance mapping of group marker -> list of
        (func, order) tuples registered via addFuncToGroup, creating it if
        necessary.

        The mapping is created lazily so the func group API still works when
        a subclass's __init__ uses it before (or without) calling
        super().__init__().

        :rtype: collections.defaultdict
        """
        try:
            return self._funcgroup_extra_funcs
        except AttributeError:
            extra_funcs = collections.defaultdict(list)
            self._funcgroup_extra_funcs = extra_funcs
            return extra_funcs

    @classmethod
    def _collectMarkedMethodData(cls):
        """
        Collects all the marked methods on the class and saves their names.

        We only collect the names (and orders) at this point because the
        methods themselves will be bound instance methods which don't exist
        at the class level.

        Classes are visited from least to most derived (reversed MRO), so a
        name defined in a more-derived class replaces the same name from a
        less-derived one, matching normal attribute lookup. Overriding a
        marked method with an unmarked attribute, or with one marked for a
        different group, removes the name from its previous group.
        """
        method_groups = collections.defaultdict(dict)
        # name -> group of the definition that currently wins lookup
        winning_group = {}
        for klass in reversed(cls.__mro__):
            for name, attr in vars(klass).items():
                try:
                    group = attr._marked_method_group
                    method_data = FuncData(name, attr._marked_method_order)
                except AttributeError:
                    group = None
                old_group = winning_group.get(name)
                if old_group is not None and group is not old_group:
                    # A marked method was overridden by something that is not
                    # in the same group anymore.
                    del method_groups[old_group][name]
                if group is None:
                    winning_group.pop(name, None)
                else:
                    method_groups[group][name] = method_data
                    winning_group[name] = group
        cls._marked_method_groups = method_groups

    def _collectMarkedMethods(self):
        """
        Collects the instance methods corresponding to the unbound methods
        that are marked on the class.

        :return: group marker -> {method name: FuncData whose ``func`` is the
            bound method looked up on this instance}
        :rtype: collections.defaultdict
        """
        marked_method_groups = collections.defaultdict(dict)
        for group, method_dict in self._marked_method_groups.items():
            for name, class_method_data in method_dict.items():
                method = getattr(self, name)
                method_data = FuncData(name, class_method_data.order, method)
                marked_method_groups[group][name] = method_data
        return marked_method_groups

    def getFuncGroup(self, group=None):
        """
        Retrieve the functions belonging to the specified group.

        :param group: the group marker. Defaults to the default group.
        :type group: FuncGroupMarker or None
        :return: the functions in the specified group, in order
        :rtype: list
        """
        if group is None:
            group = self._default_funcgroup
        method_dict = self._collectMarkedMethods()[group]
        all_func_tuples = [
            (data.func, data.order) for data in method_dict.values()
        ]
        all_func_tuples.extend(self._getExtraFuncsByGroup()[group])
        # list.sort is stable: entries with equal order keep their relative
        # position (marked methods as collected, then added funcs in the
        # order they were added).
        all_func_tuples.sort(key=lambda item: item[1])
        return [func for func, _order in all_func_tuples]

    def addFuncToGroup(self, func, group=None, order=None):
        """
        Adds a function to the specified group. Typically used for adding
        functions that are not methods of this object.

        The function may optionally be decorated with a FuncGroupMarker. If
        so, the default group and order will be determined by the decorator.
        Any group or order explicitly passed in to addFuncToGroup will take
        precedence over the decorator settings.

        :param func: the function to add
        :param group: the group marker. If the function is decorated with a
            FuncGroupMarker, that group marker will be the default.
        :type group: FuncGroupMarker or None
        :param order: the sorting order. If the function is decorated with a
            FuncGroupMarker, the order specified in the decorator will be the
            default. Otherwise the default is 0.
        :type order: float or None
        """
        if group is None:
            try:
                group = func._marked_method_group
            except AttributeError:
                group = self._default_funcgroup
        if order is None:
            try:
                order = func._marked_method_order
            except AttributeError:
                order = 0
        self._getExtraFuncsByGroup()[group].append((func, order))

    def getAddedFuncs(self, group=None):
        """
        Retrieve the functions added to the specified group via
        addFuncToGroup.

        :param group: the group marker. Defaults to the default group.
        :type group: FuncGroupMarker or None
        :return: the stored list of (func, order) tuples for the group, in
            the order they were added. This is the internal list, not a copy.
        :rtype: list
        """
        if group is None:
            group = self._default_funcgroup
        return self._getExtraFuncsByGroup()[group]