Source code for schrodinger.application.matsci.codeutils

"""
Module for utilities related to code maintenance

Copyright Schrodinger, LLC. All rights reserved.
"""
import ast
import importlib
import os
import pathlib
import warnings
from collections import namedtuple

import decorator
import networkx

# Dotted-path prefix prepended to module names that are given relative to the
# matsci package
MATSCI_MODULE_BASE = 'schrodinger.application.matsci.'
# Names used to exclude third party code: get_matsci_module_paths skips any
# walked directory whose path contains one of them, and get_imports skips
# imports whose dotted path contains one of them
THIRD_PARTY_MODULE_DIRS = ['qexsd', 'qb_sdk', 'qeschema', 'gsas']
# Package names that get_safe_package accepts after the "desmond." prefix
DESMOND_PACKAGES = [
    'analysis', 'cui', 'energygroup', 'parch', 'pfx', 'staf', 'timer', 'topo',
    'traj', 'traj_util', 'viparr', 'msys'
]
# Record yielded by get_imports: `parents` is the list of dotted-path
# components of the module named in a "from <module> import ..." statement
# (empty otherwise) and `name` is the imported name
ModuleInfo = namedtuple('ModuleInfo', ['parents', 'name'])


def check_moved_variables(var_name, moved_variables):
    """
    Check if the target variable has been moved, and if yes, post a warning
    about it and return the variable in the new module

    Raises AttributeError if a moved variable isn't found.

    :param str var_name: Name of the target variable

    :param tuple moved_variables: Tuple of tuples. Each inner tuple has a
        format of (module, remove_release, variables), where `module` is the
        new module name, `remove_release` is the release in which the link
        will stop working, and `variables` are the set of variables that were
        moved

    :raise AttributeError: If `var_name` is not a moved variable

    :rtype: Any
    :return: The moved variable
    """
    for new_module_name, remove_release, variables in moved_variables:
        if var_name not in variables:
            continue

        # The variable was moved. Show a warning and return the new variable.
        if not new_module_name.startswith('schrodinger'):
            # Convert to full path
            new_module_name = MATSCI_MODULE_BASE + new_module_name
        msg = (
            f"'{var_name}' has been moved to the '{new_module_name}' module. The"
            f" old usage will stop functioning in {remove_release} release.")
        # stacklevel=3 attributes the warning to the caller of this
        # function's caller rather than to this helper
        warnings.warn(msg, FutureWarning, stacklevel=3)
        new_module = importlib.import_module(new_module_name)
        return getattr(new_module, var_name)

    # Give the failed lookup a message so the error names what was requested
    raise AttributeError(f"'{var_name}' is not a moved variable")
@decorator.decorator
def deprecate(func, to_remove_in=None, replacement=None, *args, **kwargs):
    """
    Emit a FutureWarning saying that the function is deprecated, then call it
    with the passed arguments and return its result

    :param callable func: The function that is deprecated

    :param str to_remove_in: The release in which the function will be removed

    :param callable replacement: The function to call instead. When given, it
        is named in the warning message.
    """

    def dotted_name(obj):
        # __qualname__ carries the owning class for methods as well
        return f"{obj.__module__}.{obj.__qualname__}"

    pieces = [
        f"{dotted_name(func)} is deprecated and will be "
        f"removed in {to_remove_in} release. "
    ]
    if replacement:
        pieces.append(f'Please use {dotted_name(replacement)} instead.')
    warnings.warn(''.join(pieces), FutureWarning, stacklevel=3)
    return func(*args, **kwargs)
def is_python_file(path):
    """
    Tell whether a path names a python file, judged only by its extension
    (compared case-insensitively against ".py")

    :param str path: The file path

    :rtype: bool
    :return: Whether the path is a python file
    """
    _, extension = os.path.splitext(path)
    return extension.lower() == '.py'
def get_matsci_module_paths():
    """
    Walk the directory that contains this file and map every python file
    found (other than `__init__.py`) to its dotted module path. Directories
    whose path contains one of THIRD_PARTY_MODULE_DIRS are skipped.

    :return dict: A dict mapping absolute file paths to dot paths, ordered by
        file path
    """
    package_dir = os.path.dirname(__file__)
    dot_paths = {}
    for dir_path, _, file_names in os.walk(package_dir):
        # Substring test against the whole directory path
        if any(third_party in dir_path
               for third_party in THIRD_PARTY_MODULE_DIRS):
            continue
        for file_name in file_names:
            if file_name == '__init__.py' or not is_python_file(file_name):
                continue
            full_path = os.path.join(dir_path, file_name)
            relative = os.path.relpath(full_path, start=package_dir)
            dotted = MATSCI_MODULE_BASE + '.'.join(
                pathlib.Path(relative).parts)
            # Drop the trailing three-character ".py" extension
            dot_paths[full_path] = dotted[:-3]
    return {path: dot_paths[path] for path in sorted(dot_paths)}
class MissingModule:
    """
    Placeholder object handed back in place of a module that could not be
    imported. Looking up any attribute that is not otherwise defined on it
    raises ImportError.
    """

    def __getattr__(self, attr):
        # Only reached when normal attribute lookup fails
        raise ImportError('Unable to import desmond packages.')
def get_safe_package(name):
    """
    Import a desmond package by name, returning a MissingModule placeholder
    instead of raising when the import itself fails

    :param str name: "namespace.package". Only names starting with "desmond."
        are accepted here, although the error message also mentions jaguar.
        The package is the last dot-separated component of the name.

    :raises ValueError: If the name does not start with "desmond."
    :raises ImportError: If the package is not one of DESMOND_PACKAGES

    :rtype: module or MissingModule
    :return: The module or a MissingModule object
    """
    if not name.startswith('desmond.'):
        raise ValueError('The name should start with "desmond" or "jaguar".')

    package = name.rsplit('.', 1)[-1]
    if package not in DESMOND_PACKAGES:
        raise ImportError(
            f'"{package}" is not the name of a desmond package.')

    module_path = "schrodinger.application.desmond.packages." + package
    try:
        return importlib.import_module(module_path)
    except ImportError:
        return MissingModule()
def get_imports(script_path, parent_required='matsci'):
    """
    Gets all the imported module/methods in the passed script. Only the
    top-level statements of the script are inspected.

    :param script_path: The script path
    :type script_path: str

    :param parent_required: Only include the methods and modules from the
        passed parent. A falsy value disables this filter.
    :type parent_required: str

    :returns: iterator of module info of imported module/methods. For
        "from <module> import <name>" the info holds the dotted components of
        <module> as parents; for "import <name>" and "from . import <name>"
        parents is empty.
    :rtype: iter
    """

    def is_wanted(parent_check):
        # Check if parent module is present
        if parent_required and parent_required not in parent_check:
            return False
        # Check third party
        return not set(parent_check).intersection(THIRD_PARTY_MODULE_DIRS)

    # Parse the script
    with open(script_path, encoding='utf-8') as fp_script:
        parsed_script = ast.parse(fp_script.read(), script_path)

    # Get modules from the passed script
    for node in ast.iter_child_nodes(parsed_script):
        if isinstance(node, ast.ImportFrom):
            # from <parent> import <module>
            parents = node.module.split('.') if node.module else []
            if node.level:
                # Relative import (from . import <module> or
                # from .<parent> import <module>): the package is given by
                # the location of the script, so check its path as well
                parent_check = script_path.split(os.sep) + parents
            else:
                parent_check = parents
            if not is_wanted(parent_check):
                continue
            for node_info in node.names:
                yield ModuleInfo(parents, node_info.name)
        elif isinstance(node, ast.Import):
            # import <module>
            # import schrodinger.application.matsci.<module> as <module>
            # Each name of "import a, b" has its own dotted path, so filter
            # them one by one
            for node_info in node.names:
                if is_wanted(node_info.name.split('.')):
                    yield ModuleInfo([], node_info.name)
def get_matsci_module_graph():
    """
    Get the directed graph for matsci module imports

    :returns: Directed graph where each node is a matsci module and the edge
        represents importing of the module. Edges point from the imported
        module to the module that imports it.
    :rtype: `networkx.DiGraph`
    """
    # Get matsci modules, keyed by file path and named by file base name
    matsci_modules = {
        module_path: os.path.splitext(os.path.basename(module_path))[0]
        for module_path in get_matsci_module_paths()
    }
    # Built once so the membership test below does not rescan the dict values
    module_names = set(matsci_modules.values())

    # Create directed graph
    module_graph = networkx.DiGraph()
    for filepath, name in matsci_modules.items():
        for imported_module in get_imports(filepath):
            if imported_module.name in module_names:
                i_name = imported_module.name
            elif imported_module.parents:
                # Get the parent module incase the method was imported
                i_name = imported_module.parents[-1]
            else:
                # No parents are recorded for "import a.b.<module>" or
                # "from . import <name>"; use the last dotted component
                # instead of indexing into an empty list
                i_name = imported_module.name.rsplit('.', 1)[-1]
            module_graph.add_edge(i_name, name)
    return module_graph