# Source code for schrodinger.ui.qt.network_visualizer

"""
network_visualizer.py

Description: This module is meant to help with the visualization of
network-connection data, in conjunction with network_views.py. A good example
of the type of data this is meant to visualize is at:
https://networkx.org/

The `Graph` class is meant as a wrapper for `networkx.Graph` objects, which can
then act as a model for `AbstractNetworkView` and the associated network view
classes defined in network_views.py.

Copyright Schrodinger, LLC. All rights reserved.
"""

#Author:  Pat Lorton

import copy
import math
from sys import maxsize

import networkx as nx
import numpy as np

from schrodinger.Qt.QtCore import QObject
from schrodinger.Qt.QtCore import pyqtSignal

#### Spring layout parameters ####
# NOTE(review): STARTING_TEMPERATURE, PUSH and PUSHEXP are not referenced in
# the visible part of this module; presumably they are consumed by
# spring_layout() (defined elsewhere in this file) -- TODO confirm.
STARTING_TEMPERATURE = .1  # presumably the initial annealing temperature
# Iterations per spring layout run (used by Graph.springLayout and
# Graph.minCrossingSpringLayout)
ITERATIONS = 100
SCALE = 3.0  # Scale of distances to node size. Higher number is greater separation
PUSH = 1  # Multiplier for node-node repulsion
PUSHEXP = 6  # Exponential dependence of node-node repulsion

#===============================================================================
# Graph Model Classes
#===============================================================================


class GraphSignals(QObject):
    """
    Qt signals emitted by a `Graph` model. `Graph` itself is not a QObject,
    so it owns an instance of this class as `Graph.signals`.
    """
    # (items, source): the new set of selected nodes/edges and the object
    # that requested the selection change
    selectionChanged = pyqtSignal(set, object)
    # set of nodes whose position changed
    positionChanged = pyqtSignal(set)
    # set of nodes whose data changed
    nodesChanged = pyqtSignal(set)
    # set of newly added nodes
    nodesAdded = pyqtSignal(set)
    # set of removed nodes
    nodesDeleted = pyqtSignal(set)
    # set of added, removed or modified edges
    edgesChanged = pyqtSignal(set)
    # the graph data or whole graph state was replaced
    graphChanged = pyqtSignal()
    # a new undo point was recorded
    undoPointSet = pyqtSignal()
class Graph:
    """
    A model class for an undirected graph. This wraps around the NetworkX
    Graph class and provides Qt signals, an easier-to-use API, and access
    control. All persistent data should be stored in self._ggraph.

    Note that Graph itself cannot be pickled; Graph has Graph.signals, which
    is a QObject and cannot be pickled. For this reason selection information
    (which contains references to Graph) is not placed in self._ggraph, so
    that self._ggraph can be pickled.
    """

    def __init__(self, ggraph=None, node_class=None, edge_class=None):
        """
        Constructs a new Graph object

        :param ggraph: The graph underlying this graph. An empty
            `networkx.Graph` is created if not supplied.
        :type ggraph: `networkx.Graph` or NoneType
        :param node_class: The class to represent the graph's nodes (should
            be subclass of `Node`). Defaults to `Node`.
        :type node_class: class
        :param edge_class: The class to represent the graph's edges (should
            be subclass of `Edge`). Defaults to `Edge`.
        :type edge_class: class
        """
        self.signals = GraphSignals()
        if ggraph is None:
            ggraph = nx.Graph()
        self.node_class = node_class or Node
        self.edge_class = edge_class or Edge
        self._ggraph = ggraph
        self.selected_nodes = set()
        self.selected_edges = set()
        # Node wrappers keyed by the string form of their gnode
        self.node_objects = {}
        self._updateNodeMap()
        self.connection_validator = None
        self.undo_stack = []
        self.max_undo_stack = 100
        self.redo_stack = []

    @property
    def ggraph(self):
        """
        :return: the underlying NetworkX graph (not a copy)
        :rtype: `networkx.Graph`
        """
        return self._ggraph

    def update(self):
        """
        Update any derived aspects of the graph after changes. Does nothing
        here; subclasses may override.
        """
        return

    def _updateNodeMap(self):
        """
        Update the `node_objects` dictionary with any new nodes from the
        underlying ggraph.

        :return: the newly created node wrappers
        :rtype: set(Node)
        """
        new_nodes = set()
        for gnode in self.ggraph.nodes:
            if str(gnode) not in self.node_objects:
                node = self.node_class(gnode, self)
                self.node_objects[node.name] = node
                new_nodes.add(node)
        return new_nodes

    def setEdgeValidator(self, validator):
        """
        Set an edge validator that will be run when adding edges between
        nodes.

        :param validator: the validator
        :type validator: ConnectionValidator
        :raises TypeError: if `validator` is not a `ConnectionValidator`
        """
        if not isinstance(validator, ConnectionValidator):
            raise TypeError('Validator must be a subclass of '
                            'ConnectionValidator')
        self.connection_validator = validator

    def toNetworkX(self):
        """
        Return a copy of the underlying NetworkX graph.

        :rtype: `networkx.Graph`
        """
        return copy.deepcopy(self._ggraph)

    def getData(self, key):
        """
        Return the requested item from the graph's data dictionary. Returns
        None if the key is not found.
        """
        return self._ggraph.graph.get(key)

    def setData(self, key, value, signal=True):
        """
        Set the value of an item in the graph's data dictionary.

        :param signal: whether to emit `graphChanged` afterwards
        :type signal: bool
        """
        self._ggraph.graph[key] = value
        if signal:
            self.signals.graphChanged.emit()

    def isConnected(self):
        """
        Checks whether the graph is connected, that is, whether every node is
        connected by some path to every other node. A graph with no nodes is
        reported as not connected.

        :return: Whether the graph is connected
        :rtype: bool
        """
        # nx.is_connected raises on the null graph, so guard against it and
        # return a real bool (not the empty graph object) in that case.
        if self._ggraph.number_of_nodes() == 0:
            return False
        return nx.is_connected(self._ggraph)

    #===========================================================================
    # Node methods
    #===========================================================================

    def nodeCount(self):
        """
        :return: the number of nodes in the graph
        :rtype: int
        """
        return self.ggraph.number_of_nodes()

    def getIsolates(self):
        """
        :return: a complete set of nodes in the graph that have degree 0
        :rtype: set(Node)
        """
        return {self.getNode(gnode) for gnode in nx.isolates(self.ggraph)}

    def getConnectedComponents(self, nodes=None):
        """
        Return a set of nodes for each connected component in the graph.

        :param nodes: optionally, a set of nodes to filter the returned
            components. If provided, this method will only return components
            for which at least one node is in `nodes`
        :type nodes: set(Node) or NoneType
        :return: a generator over each connected component in the graph
        :rtype: typing.Generator[set[Node], None, None]
        """
        for gnodes in nx.connected_components(self.ggraph):
            component_nodes = {self.getNode(gnode) for gnode in gnodes}
            if nodes is None or nodes.intersection(component_nodes):
                yield component_nodes

    def getNodeConnectedComponent(self, node):
        """
        Return a set of nodes that are part of the same connected component
        as `node`.

        :param node: a node
        :type node: Node
        :return: a set of nodes connected to `node` through any path of edges
        :rtype: set(Node)
        """
        return {
            self.getNode(gnode)
            for gnode in nx.node_connected_component(self.ggraph, node.gnode)
        }

    def _getNodeInstances(self, objects):
        """
        Given a list of objects, return only the `Node` instances among them.

        :param objects: a list of objects
        :type objects: list(object)
        :return: the set of node instances among the supplied list of objects
        :rtype: set(Node)
        """
        return {obj for obj in objects if isinstance(obj, self.node_class)}

    def _getEdgeInstances(self, objects):
        """
        Given a list of objects, return only the `Edge` instances among them.

        :param objects: a list of objects
        :type objects: list(object)
        :return: the set of edge instances among the supplied list of objects
        :rtype: set(Edge)
        """
        return {obj for obj in objects if isinstance(obj, self.edge_class)}

    def _resolveNode(self, node_key):
        """
        Return the `Node` designated by `node_key`.

        `Node` instances are returned unchanged; any other key (a gnode or a
        node name) is looked up with `getNode`.

        :param node_key: a node, gnode, or node name
        :type node_key: object
        :return: the corresponding node
        :rtype: Node
        :raises ValueError: if `node_key` is not a `Node` and matches no node
            in this graph
        """
        if isinstance(node_key, Node):
            return node_key
        node = self.getNode(node_key)
        if node is None:
            raise ValueError('Node %s not found in graph.' % (node_key,))
        return node

    def getNode(self, node_key):
        """
        Retrieve a node via its name. Retrieved nodes are cached, so getting
        the same Node again will return the same instance. Returns None if no
        matching Node exists.

        :param node_key: a node, gnode, or string that corresponds to the
            desired node
        :type node_key: object
        :return: a node if found, else `None`
        :rtype: Node or NoneType
        """
        if isinstance(node_key, self.node_class):
            return node_key
        return self.node_objects.get(str(node_key))

    def getNodes(self, node_keys=None):
        """
        Retrieve a set of nodes optionally indicated by a list of keys. If
        none is provided, return all nodes. Keys that match no node are
        skipped.

        :param node_keys: optionally, a list of nodes, gnodes, or strings that
            correspond to the desired nodes
        :type node_keys: list(object) or NoneType
        :return: a set of nodes
        :rtype: set(Node)
        """
        if node_keys is None:
            return set(self.node_objects.values())
        nodes = set()
        for node_key in node_keys:
            node = self.getNode(node_key)
            if node:
                nodes.add(node)
        return nodes

    def getNeighbors(self, node):
        """
        Return a set of all nodes connected to a specified node

        :param node: center node
        :type node: Node
        :return: neighboring nodes
        :rtype: set of Node
        """
        gnodes = self._ggraph.neighbors(node.gnode)
        return self.getNodes(gnodes)

    def addNodes(self, nodes, signal=True):
        """
        Add a list of nodes to this graph. The `nodes` argument can either be
        a list of `Node` objects or a list of hashable objects that can be
        used as new gnodes.

        Note that any time a new gnode is created for use in this graph, its
        string representation must be unique among the other nodes in this
        graph: nodes are keyed in the `node_objects` dictionary by the string
        representation of their corresponding gnode.

        :param nodes: list of gnodes or nodes
        :type nodes: `list(object)` or `list(Node)`
        :param signal: whether the `nodesAdded` signal should be emitted when
            done
        :type signal: bool
        :return: a set of added nodes
        :rtype: set(Node)
        :raises ValueError: if a node (or its string representation) already
            exists in this graph
        """
        for node in nodes:
            if not isinstance(node, Node):
                if str(node) in self.node_objects:
                    msg = ('A node with the same string representation ("{0}")'
                           ' already exists in this graph. New nodes must have'
                           ' a unique string representation.').format(node)
                    raise ValueError(msg)
                new_node = self.node_class(node, self)
            else:
                new_node = node
            if new_node.gnode in self.ggraph.nodes:
                msg = 'Node %s already exists in graph.' % new_node.gnode
                raise ValueError(msg)
            self._ggraph.add_node(new_node.gnode, **new_node.gdata())
        # Wrappers are (re)created from the underlying graph, so the returned
        # nodes share their data dictionaries with self._ggraph.
        new_nodes = self._updateNodeMap()
        if signal and new_nodes:
            self.signals.nodesAdded.emit(new_nodes)
        return new_nodes

    def addNode(self, node, signal=True):
        """
        Convenience method for adding a single node to the graph. See
        `addNodes()` for full documentation.

        :param node: gnode or node
        :type node: hashable or `Node`
        :param signal: whether the `nodesAdded` signal should be emitted when
            done
        :type signal: bool
        :return: the added node
        :rtype: Node
        """
        new_nodes = self.addNodes([node], signal=signal)
        return new_nodes.pop()

    def removeNodes(self, nodes, signal=True):
        """
        Remove specified nodes from the graph and optionally emit a signal.

        :param nodes: the nodes (or gnodes/node names) to be removed
        :type nodes: iterable(Node or object)
        :param signal: whether to emit a `nodesDeleted` signal when done
        :type signal: bool
        """
        # Materialize once: the argument is traversed several times below, so
        # a generator would otherwise be exhausted after the first pass.
        resolved = []
        for node in nodes:
            if not isinstance(node, Node):
                found = self.getNode(node)
                if found is not None:
                    node = found
            resolved.append(node)
        self._ggraph.remove_nodes_from([n.gnode for n in resolved])
        for node in resolved:
            del self.node_objects[node.name]
        if signal:
            self.signals.nodesDeleted.emit(set(resolved))

    def removeNode(self, node, signal=True):
        """
        Convenience function for removing a single node. See `removeNodes()`
        for full documentation.

        :param node: a gnode or node to remove
        :type node: `object` or `Node`
        :param signal: whether to emit a `nodesDeleted` signal when done
        :type signal: bool
        """
        self.removeNodes([node], signal=signal)

    def setMultipleNodePos(self, pos_dict, signal=True):
        """
        Set the positions of nodes from a dictionary.

        :param pos_dict: A dictionary mapping nodes to (x,y) tuples.
        :type pos_dict: dict {Node : (int, int)}
        :param signal: whether to emit one `positionChanged` signal for all
            nodes when done
        :type signal: bool
        """
        changednodes = set()
        for node, pos in pos_dict.items():
            node.setPos(pos[0], pos[1], False)
            changednodes.add(node)
        if signal:
            self.signals.positionChanged.emit(changednodes)

    #===========================================================================
    # Edge methods
    #===========================================================================

    def edgeCount(self):
        """
        :return: the number of edges in the graph
        :rtype: int
        """
        return self.ggraph.number_of_edges()

    def hasEdge(self, node1, node2):
        """
        Return whether there is an edge between the supplied nodes.

        :param node1: a node from this graph
        :type node1: Node
        :param node2: a node from this graph
        :type node2: Node
        :return: whether there exists an edge between the two supplied nodes
        :rtype: bool
        """
        return self._ggraph.has_edge(node1.gnode, node2.gnode)

    def getGEdge(self, node0, node1):
        """
        Return the underlying gedge object corresponding to two supplied
        nodes. This can be overwritten in subclasses, but the returned class
        should define a consistent edge ordering that is independent of the
        order of the supplied node parameters.

        :param node0: a node
        :type node0: Node
        :param node1: a node
        :type node1: Node
        :return: the underlying gedge between the two nodes, if it exists
        :rtype: tuple(object) or NoneType
        """
        if not self.hasEdge(node0, node1):
            return None
        return tuple(sorted([node0.gnode, node1.gnode]))

    def getEdge(self, node0, node1):
        """
        Given two nodes, return the corresponding edge if it exists.

        :param node0: a node
        :type node0: Node
        :param node1: a node
        :type node1: Node
        :return: the edge connecting the two nodes if it exists
        :rtype: Edge or NoneType
        """
        gedge = self.getGEdge(node0, node1)
        if gedge:
            return self.edge_class(gedge, self)
        return None

    def getEdges(self, nodes=None):
        """
        Return all edges connected to a node or set of nodes. If no node is
        specified, all the edges in the graph are returned.

        :param nodes: optionally a node or iterable of nodes
        :type nodes: `iterable(Node)`, `Node`, or `None`
        :return: a set of edges connected to at least one of the supplied
            nodes, or a set of all edges if `nodes` is not specified
        :rtype: set(Edge)
        """
        if nodes is None:
            gnodes = None
        else:
            if not hasattr(nodes, '__iter__'):
                nodes = [nodes]
            gnodes = [node.gnode for node in nodes]
        edges = set()
        for gnode1, gnode2 in self.ggraph.edges(gnodes):
            node1, node2 = self.getNode(gnode1), self.getNode(gnode2)
            edges.add(self.getEdge(node1, node2))
        return edges

    def addEdges(self, edge_tuples, signal=True):
        """
        Add edges to graph.

        :param edge_tuples: list of tuples indicating the edges to add,
            containing two gnodes or nodes and an edge attribute dictionary
            (or `None`)
        :type edge_tuples: list(tuple(Node, Node, dict)) or
            list(tuple(Node, Node, None))
        :param signal: whether `edgesChanged` signal should be emitted when
            done
        :type signal: bool
        :raises ValueError: if an edge already exists, or a gnode/node name
            does not correspond to a node in this graph
        """
        new_edges = set()
        for key1, key2, data_dict in edge_tuples:
            # Accept gnodes and names as documented, not only Node objects
            node1, node2 = self._resolveNode(key1), self._resolveNode(key2)
            if data_dict is None:
                data_dict = {}
            if self.hasEdge(node1, node2):
                msg = 'Edge {} already exists in graph.'.format((node1, node2))
                raise ValueError(msg)
            self._ggraph.add_edge(node1.gnode, node2.gnode, **data_dict)
            new_edges.add(self.getEdge(node1, node2))
        if signal:
            self.signals.edgesChanged.emit(new_edges)

    def addEdge(self, node1, node2, signal=True, data=None):
        """
        Convenience function to add a single edge to the graph given two
        nodes. The order of the nodes does not matter.

        :param node1: a gnode or node connected by the edge
        :type node1: `object` or `Node`
        :param node2: a gnode or node connected by the edge
        :type node2: `object` or `Node`
        :param signal: whether `edgesChanged` signal should be emitted when
            done
        :type signal: bool
        :param data: optional edge attribute dictionary
        :type data: dict or NoneType
        """
        self.addEdges([(node1, node2, data)], signal=signal)

    def removeEdges(self, edges, signal=True):
        """
        Removes specified edges from the graph.

        :param edges: edges, or (Node, Node) pairs
        :type edges: iterable(Edge)
        :param signal: whether `edgesChanged` signal should be emitted when
            done
        :type signal: bool
        :raises ValueError: if an edge is not present in the graph
        """
        # Materialize once: the argument is traversed twice (a generator
        # would yield an empty set for the signal otherwise).
        edges = list(edges)
        for edge in edges:
            node1, node2 = edge
            if not self.hasEdge(node1, node2):
                raise ValueError('Edge not found between %s and %s' %
                                 (node1, node2))
            self._ggraph.remove_edge(node1.gnode, node2.gnode)
        if signal:
            self.signals.edgesChanged.emit(set(edges))

    def removeEdge(self, edge, signal=True):
        """
        Convenience function to remove a single edge from the graph.

        :param edge: an edge
        :type edge: Edge
        :param signal: whether `edgesChanged` signal should be emitted when
            done
        :type signal: bool
        """
        self.removeEdges([edge], signal=signal)

    def getEdgeApproval(self, node1, node2):
        """
        Test whether a new edge can be added between two nodes. Doesn't
        actually add an edge, just returns whether it is allowable to add.

        :return: (allowed, message) pair
        :rtype: tuple(bool, str)
        """
        if self.hasEdge(node1, node2):
            return False, "This connection already exists"
        if node1 == node2:
            return False, "Can't connect a node to itself."
        if self.connection_validator:
            return self.connection_validator.validate(node1, node2)
        return True, "No Problem"

    #===========================================================================
    # Selection
    #===========================================================================

    def selectedNodes(self):
        """
        Return the currently selected nodes.

        :rtype: set of Nodes
        """
        return self.selected_nodes

    def selectedEdges(self):
        """
        :return: the set of selected edges
        :rtype: set(Edge)
        """
        return self.selected_edges

    def setSelectedObjs(self, objs, source=None, signal=True):
        """
        Specify the current selection.

        :param objs: a list of objects (nodes or edges) to be selected
        :type objs: list(Node or Edge)
        :param source: the class instance calling this method (used to avoid
            infinite recursion when updating selection state)
        :type source: object
        :param signal: whether to emit a signal when changing selection state
        :type signal: bool
        """
        nodes = self._getNodeInstances(objs)
        edges = self._getEdgeInstances(objs)
        if set.symmetric_difference(nodes, self.selectedNodes()):
            self.selected_nodes = nodes
        if set.symmetric_difference(edges, self.selectedEdges()):
            self.selected_edges = edges
        items = nodes.union(edges)
        if signal:
            self.signals.selectionChanged.emit(items, source)

    #===========================================================================
    # Layout methods
    #===========================================================================

    def springLayout(self, signal=True):
        """
        Performs a spring layout on the current graph.

        :param signal: whether to emit `positionChanged` when done
        :type signal: bool
        """
        node_coord_map = self._getSpringLayoutCoords(iterations=ITERATIONS,
                                                     weight_attr=None,
                                                     scale=SCALE)
        self.setMultipleNodePos(node_coord_map, signal)

    def _getSpringLayoutCoords(self,
                               dim=2,
                               node_pos_map=None,
                               fixed_nodes=None,
                               iterations=50,
                               weight_attr='weight',
                               scale=1):
        """
        Calculate and return a dictionary mapping nodes to optimally-computed
        Cartesian coordinates for each node. Convenience method that wraps
        `spring_layout()`.

        :param dim: number of dimensions of the layout
        :type dim: int
        :param node_pos_map: optionally, initial positions for nodes;
            otherwise, use random initial positions
        :type node_pos_map: dict(Node, tuple(float))
        :param fixed_nodes: optionally, a list of nodes to keep fixed at
            their initial positions
        :type fixed_nodes: list(Node)
        :param iterations: number of iterations of spring-force relaxation
        :type iterations: int
        :param weight_attr: the edge attribute that holds the numerical value
            used for the edge weight. If None, then all edge weights are 1.
        :type weight_attr: str or None
        :param scale: scale factor for positions
        :type scale: float
        :return: a dictionary mapping nodes to their calculated positions
        :rtype: dict(Node, tuple(float))
        """
        if fixed_nodes:
            fixed_gnodes = [n.gnode for n in fixed_nodes]
        else:
            fixed_gnodes = None
        if node_pos_map:
            gnode_pos_map = {n.gnode: pos for n, pos in node_pos_map.items()}
        else:
            gnode_pos_map = None
        node_coords = spring_layout(self.ggraph,
                                    dim=dim,
                                    pos=gnode_pos_map,
                                    fixed=fixed_gnodes,
                                    iterations=iterations,
                                    weight=weight_attr,
                                    scale=scale)
        node_coord_map = {}
        for name, coords in node_coords.items():
            node = self.getNode(name)
            node_coord_map[node] = coords
        return node_coord_map

    def minCrossingSpringLayout(self,
                                num_iterations=100,
                                fixed_nodes=None,
                                fraction=1.0):
        """
        Perform multiple spring layouts and keep the one with the fewest edge
        intersections, keeping the original positions if the layout could
        not be improved.

        :param num_iterations: number of spring layouts to try
        :type num_iterations: int
        :param fixed_nodes: nodes for which the position should be fixed
        :type fixed_nodes: iterable of Node or NoneType
        :param fraction: stop iterating if no reduction in crossings is found
            within this fraction of num_iterations
        :type fraction: float
        """
        edges = self.getEdges()

        # Stored positions are in scaled (display) coordinates whereas the
        # spring layout works in unscaled ones. Unscale the fixed positions
        # once, so that fixed nodes keep their location after the final
        # re-scaling -- including on restarts after a non-improving attempt.
        fixed_pos_map = None
        if fixed_nodes is not None:
            # Materialize: fixed_nodes is consumed here and in every layout
            fixed_nodes = list(fixed_nodes)
            fixed_pos_map = {node: node.pos() for node in fixed_nodes}
            self._scaleNodeCoords(fixed_pos_map, reverse=True)

        min_crossings = maxsize
        start_pos_map = fixed_pos_map
        # If the nodes already have positions, that layout is the one to beat
        if self.hasPositions():
            current_pos_map = {node: node.pos() for node in self.getNodes()}
            _, min_crossings = _has_fewer_crossings(edges, current_pos_map,
                                                    maxsize)
            # unscale coordinates to preserve location
            self._scaleNodeCoords(current_pos_map, reverse=True)
            start_pos_map = current_pos_map

        best_pos_map = None  # stays None unless a layout beats the start
        new_pos_map = start_pos_map
        max_not_better_iters = min(num_iterations, fraction * num_iterations)
        not_better_iters = 0
        for _ in range(num_iterations):
            not_better_iters += 1
            if min_crossings == 0 or not_better_iters > max_not_better_iters:
                break
            new_pos_map = self._getSpringLayoutCoords(iterations=ITERATIONS,
                                                      node_pos_map=new_pos_map,
                                                      fixed_nodes=fixed_nodes,
                                                      weight_attr=None,
                                                      scale=SCALE)
            is_better, crossings = _has_fewer_crossings(edges, new_pos_map,
                                                        min_crossings)
            if is_better:
                not_better_iters = 0
                min_crossings = crossings
                best_pos_map = new_pos_map
            else:
                # Restart from only the fixed positions (other nodes are
                # presumably placed randomly by the layout)
                new_pos_map = fixed_pos_map

        if best_pos_map is None:
            # Nothing improved on the starting layout; leave positions as-is
            return
        self._scaleNodeCoords(best_pos_map)
        self.setMultipleNodePos(best_pos_map)

    def _scaleNodeCoords(self, pos_dict, reverse=False):
        """
        Scales the positions in the pos_dict dictionary in place by factor,
        or by 1/factor if reverse is True, where
        factor = 0.5 * sqrt(number of nodes). Through manual testing 0.5 was
        determined to be a good multiplier.

        :param pos_dict: A dictionary mapping nodes or node names to (x,y)
            tuples.
        :type pos_dict: dict {Node : (int, int)}
        :param reverse: Whether to reverse the scaling
        :type reverse: bool
        :return: `pos_dict`, for convenience
        :rtype: dict
        """
        num_nodes = self._ggraph.number_of_nodes() or 1
        scale = 0.5 * math.sqrt(num_nodes)
        if reverse:
            scale = 1.0 / scale
        return self._scaleDictPositions(pos_dict, scale)

    @staticmethod
    def _scaleDictPositions(pos_dict, factor):
        """
        Multiplies the positions in {node: (x_pos, y_pos)} dictionary by
        factor, in place. Values are replaced by [x, y] lists.

        :param pos_dict: A dictionary mapping nodes to (x,y) tuples.
        :type pos_dict: dict {Node : (int, int)}
        :param factor: multiplication factor for positions
        :type factor: float
        :return: `pos_dict`
        :rtype: dict
        """
        for node, xy in pos_dict.items():
            # Reassigning existing keys does not resize the dict, so this is
            # safe during iteration.
            pos_dict[node] = [xy[0] * factor, xy[1] * factor]
        return pos_dict

    def hasPositions(self, accept_partial=False):
        """
        Determines whether the nodes in this graph have x-y coordinates.

        :param accept_partial: if set to True, the method will check whether
            at least one node has coordinates. Otherwise it requires that all
            nodes have coordinates.
        :type accept_partial: bool
        :rtype: bool
        """
        fully_positioned = True
        for node in self.getNodes():
            if node.pos() is None:
                fully_positioned = False
            elif accept_partial:
                return True
        return fully_positioned

    #===========================================================================
    # Undo/redo
    #===========================================================================

    def getState(self):
        """
        Get the current state of the Graph

        :return: a deep copy of the underlying graph and a shallow copy of
            the node wrapper dictionary
        :rtype: tuple(networkx.Graph, dict(str, Node))
        """
        ggraph = copy.deepcopy(self._ggraph)
        node_objects = self.node_objects.copy()
        return ggraph, node_objects

    def setState(self, state):
        """
        Set the current state of the Graph, clearing the selection.

        :param state: a state as returned by `getState`
        :type state: tuple(networkx.Graph, dict(str, Node))
        """
        ggraph, node_objects = state
        self._ggraph = ggraph
        self.node_objects = node_objects
        # getState deep-copies the graph but only shallow-copies the node
        # wrappers, which therefore still reference the data dictionaries of
        # the graph that was live when the state was captured. Re-point them
        # at the restored graph so Node getters/setters see restored data.
        restored_gdata = ggraph.nodes
        for node in node_objects.values():
            gdata = restored_gdata.get(node.gnode)
            if gdata is not None:
                node._gdata = gdata
        self.selected_nodes = set()
        self.selected_edges = set()
        self.signals.graphChanged.emit()

    def setUndoPoint(self, signal=True):
        """
        Store the current state to the undo stack. Also wipes out the redo
        stack.

        :param signal: whether to emit `undoPointSet`
        :type signal: bool
        """
        self.undo_stack.append(self.getState())
        while len(self.undo_stack) > self.max_undo_stack:
            self.undo_stack.pop(0)
        self.redo_stack = []
        if signal:
            self.signals.undoPointSet.emit()

    def undo(self):
        """
        Revert to the last state on the undo stack.
        """
        if not self.undo_stack:
            return
        self.redo_stack.append(self.getState())
        while len(self.redo_stack) > self.max_undo_stack:
            self.redo_stack.pop(0)
        state = self.undo_stack.pop()
        self.setState(state)

    def redo(self):
        """
        Undo the undo
        """
        if not self.redo_stack:
            return
        self.undo_stack.append(self.getState())
        state = self.redo_stack.pop()
        self.setState(state)

    def clearUndoHistory(self):
        """
        Clears both undo and redo stacks
        """
        self.undo_stack = []
        self.redo_stack = []

    def merge(self, g):
        """
        Merge data from another graph into this graph. Nodes with duplicate
        names will be considered to be the same ligand.

        Edges from `g` without a 'direction' data item are given one, ordered
        by node name. `g` itself is not modified.

        :param g: graph from which data is being merged.
        :type g: `Graph`
        """
        merged_edges = []
        for edge in g.getEdges():
            node1, node2 = edge
            gnode1, gnode2 = edge.gedge
            data_dict = edge.data()  # a copy; g's edge data stays untouched
            if 'direction' not in data_dict:
                hex1, hex2 = node1.name, node2.name
                data_dict['direction'] = ((hex1, hex2) if hex1 < hex2 else
                                          (hex2, hex1))
            merged_edges.append((gnode1, gnode2, data_dict))
        self._ggraph.add_nodes_from(g.ggraph.nodes(data=True))
        self._ggraph.add_edges_from(merged_edges)
        # Create wrappers for the merged gnodes so getNode() can find them
        self._updateNodeMap()

    def deleteSelectedItems(self, include_edges=True, include_nodes=True):
        """
        Delete selected nodes and/or selected edges, recording an undo point
        first.

        :param include_edges: whether selected edges should be deleted
        :type include_edges: bool
        :param include_nodes: whether selected nodes should be deleted
        :type include_nodes: bool
        """
        nodes = self.selectedNodes() if include_nodes else set()
        edges = self.selectedEdges() if include_edges else set()
        if not nodes and not edges:
            return
        self.setUndoPoint()
        self.setSelectedObjs([])
        self.deleteItems(nodes, edges)

    def deleteItems(self, nodes=None, edges=None):
        """
        Delete specified nodes and edges from the graph. Edges connected to
        the deleted nodes are deleted as well.

        :param nodes: nodes to delete
        :type nodes: Set[Node]
        :param edges: edges to delete
        :type edges: Set[Edge] or Set[Tuple[Node, Node]]
        """
        nodes = nodes or set()
        edges = edges or set()
        # Convert (Node, Node) tuples to Edge objects so they deduplicate
        # against the Edge objects from getEdges(); otherwise an edge given
        # both ways would be removed twice and removeEdges would raise.
        # Missing edges are kept as-is so removeEdges still reports them.
        edge_objs = set()
        for edge in edges:
            if not isinstance(edge, Edge):
                node1, node2 = edge
                found = self.getEdge(node1, node2)
                if found is not None:
                    edge = found
            edge_objs.add(edge)
        edges = edge_objs.union(self.getEdges(nodes))
        if edges:
            self.removeEdges(edges)
        if nodes:
            self.removeNodes(nodes)
        if edges or nodes:
            self.update()
class Node:
    """
    Model class for Node. Wraps the data dictionary of a NetworkX graph node.
    """
    x_key = 'storedX'
    y_key = 'storedY'

    def __init__(self, name, graph=None):
        """
        Construct a Node object. Most of the time, this will be constructed
        around an existing NetworkX node (i.e. an entry in the
        networkx.Graph.nodes view).

        If a graph is specified and already contains a node `name`, this node
        wraps (shares) that node's data dictionary. Otherwise the node starts
        with an empty, detached data dictionary; `Graph.addNodes` relies on
        this to create nodes before inserting them into the graph.

        QT signals will only be emitted if a graph is specified.

        :param name: a unique identifier for this node
        :type name: hashable
        :param graph: the graph object to which this node belongs
        :type graph: `Graph`

        :ivar _gnode: the underlying graph node that this node wraps. In this
            class, we use the node name as the graph node, but any hashable
            object can be used.
        :ivar _gdata: dictionary that stores data belonging to the underlying
            graph node.
        """
        gdata = {}
        if graph:
            # NodeView.get returns the live attribute dict, or a fresh {} if
            # the node is not (yet) in the graph; it never raises KeyError.
            gdata = graph.ggraph.nodes.get(name, {})
        self._gnode = name
        self._gdata = gdata
        self.graph = graph

    @property
    def gnode(self):
        """
        Return the underlying graph node object wrapped by this `Node`
        instance (not the data dictionary `_gdata`).
        """
        return self._gnode

    @property
    def name(self):
        """
        Return unique string associated with this node. Convert to string for
        subclasses which do not necessarily use strings as graph nodes.
        """
        return str(self.gnode)

    #===========================================================================
    # Positioning
    #===========================================================================

    def x(self):
        """
        :return: the stored x coordinate, or None if unset
        """
        return self._gdata.get(self.x_key, None)

    def y(self):
        """
        :return: the stored y coordinate, or None if unset
        """
        return self._gdata.get(self.y_key, None)

    def pos(self):
        """
        Returns the Node's current position coordinates. Returns None if
        either coordinate is missing.

        :rtype: tuple (float, float) or NoneType
        """
        pos = (self.x(), self.y())
        if None in pos:
            return None
        return pos

    def setX(self, x, signal=True):
        """
        Set the x coordinate; emits `positionChanged` if it changed, `signal`
        is True and the node belongs to a graph.
        """
        if self.x() == x:
            return
        self._gdata[self.x_key] = x
        if signal and self.graph:
            self.graph.signals.positionChanged.emit({self})

    def setY(self, y, signal=True):
        """
        Set the y coordinate; emits `positionChanged` if it changed, `signal`
        is True and the node belongs to a graph.
        """
        if self.y() == y:
            return
        self._gdata[self.y_key] = y
        if signal and self.graph:
            self.graph.signals.positionChanged.emit({self})

    def setPos(self, x, y, signal=True):
        """
        Set the node's position coordinates

        :param x: x coordinate
        :type x: float
        :param y: y coordinate
        :type y: float
        :param signal: whether to emit a single `positionChanged` signal
        :type signal: bool
        """
        if self.x() == x and self.y() == y:
            return
        self.setX(x, False)
        self.setY(y, False)
        if signal and self.graph:
            self.graph.signals.positionChanged.emit({self})

    #===========================================================================
    # General node properties
    #===========================================================================

    def gdata(self):
        """
        Directly access the node data dictionary. Use this object carefully,
        as directly altering its contents can lead to internal
        inconsistencies. This may be wrapped to restrict access.
        """
        return self._gdata

    def getData(self, key):
        """
        Return the requested item from the node's data dictionary. Returns
        None if the key is not found.
        """
        return self._gdata.get(key, None)

    def setData(self, key, value, signal=True):
        """
        Set the value of an item in the node's data dictionary; emits
        `nodesChanged` if `signal` is True and the node belongs to a graph.
        """
        self._gdata[key] = value
        if signal and self.graph:
            self.graph.signals.nodesChanged.emit({self})

    @property
    def degree(self):
        """
        :return: the degree (number of edges) of the node
        :rtype: int
        """
        return self.graph.ggraph.degree(self.gnode)

    def __repr__(self):
        return '<Node("%s")>' % self.name

    def __str__(self):
        return self.__repr__()

    def __eq__(self, rhs):
        # Nodes are equal when they belong to the same graph object and have
        # the same name; non-Node objects compare unequal.
        try:
            return id(self.graph) == id(rhs.graph) and self.name == rhs.name
        except AttributeError:
            return False

    def __ne__(self, rhs):
        return not self == rhs

    def __hash__(self):
        return hash((id(self.graph), self.name))
class Edge:
    """
    Model class for an edge of a `Graph`.

    An edge is a lightweight handle: it stores only the underlying gedge
    (normally the node-order-independent tuple of gnodes produced by
    `Graph.getGEdge`) and the owning graph. Edge data lives in the
    underlying NetworkX graph.
    """

    def __init__(self, gedge, graph):
        """
        :param gedge: the underlying edge object wrapped by this object
        :type gedge: object
        :param graph: the graph object to which this edge belongs
        :type graph: Graph
        """
        self._gedge = gedge
        self._graph = graph

    @property
    def gedge(self):
        """
        :return: the underlying edge object wrapped by this object (normally
            a tuple of two gnodes)
        :rtype: object
        """
        return self._gedge

    @property
    def graph(self):
        """
        :return: the graph to which this edge belongs
        :rtype: Graph
        """
        return self._graph

    @property
    def nodes(self):
        """
        :return: the nodes connected by this edge, in the order given by the
            underlying gedge
        :rtype: tuple(Node, Node)
        """
        return tuple([self.graph.getNode(gnode) for gnode in self.gedge])

    def data(self):
        """
        :return: a copy of the data dictionary associated with this edge
        :rtype: dict(str, object)
        """
        ggraph = self.graph.ggraph
        endpoints = self.gedge
        return dict(ggraph[endpoints[0]][endpoints[1]])

    def getData(self, key):
        """
        Look up a single item in the edge's data dictionary.

        :param key: the data item key
        :type key: str
        :return: the stored value, or `None` when `key` is absent
        :rtype: object
        """
        return self.data().get(key)

    def setData(self, key, value, signal=True):
        """
        Store an item in the edge's data dictionary. `edgesChanged` is
        emitted only when `signal` is True and the value actually changed.

        :param key: the data item key
        :type key: str
        :param value: the value to set for the data item
        :type value: object
        """
        ggraph = self.graph.ggraph
        endpoints = self.gedge
        live_data = ggraph[endpoints[0]][endpoints[1]]
        previous = live_data.get(key)
        live_data[key] = value
        if signal and previous != value:
            self.graph.signals.edgesChanged.emit({self})

    @property
    def name(self):
        """
        :return: the name of the edge, a composite of the connected node names
        :rtype: str
        """
        first, second = self.nodes
        label0 = first.name if first is not None else 'None'
        label1 = second.name if second is not None else 'None'
        return '"{}" - "{}"'.format(label0, label1)

    def __getitem__(self, idx):
        """
        Return a node connected by this edge, so that an edge can be unpacked
        as `node1, node2 = edge`.

        :param idx: node index (0 or 1)
        :type idx: `int`
        :return: node corresponding to supplied index
        :rtype: `Node`
        """
        return self.nodes[idx]

    def __eq__(self, rhs):
        # Same owning graph and same gedge; non-Edge objects compare unequal.
        try:
            return self.graph == rhs.graph and self.gedge == rhs.gedge
        except AttributeError:
            return False

    def __ne__(self, rhs):
        return not self == rhs

    def __hash__(self):
        return hash((self.graph, self.gedge))

    def __str__(self):
        return f'<{type(self).__name__}({self.name})>'

    def __repr__(self):
        return str(self)
class ConnectionValidator:
    """
    Base class for checking whether two nodes may be connected by an edge.

    Subclass this, override `validate`, and install an instance with
    `Graph.setEdgeValidator`; `Graph.getEdgeApproval` then defers to
    `validate(node1, node2)`, which must return an (allowed, message) pair.
    """

    def __init__(self):
        self.first_node = None

    def validate(self, node1, node2):
        """
        :return: whether the connection is allowed, and an explanation
        :rtype: tuple(bool, str)
        """
        return True, "No problem"

    def firstNode(self):
        """
        :return: the node recorded by `setFirstNode`, or None
        """
        return self.first_node

    def setFirstNode(self, node):
        """
        Record the first node of a connection being built.
        """
        self.first_node = node

    def validateSecondVal(self, val):
        """
        Validate `val` against the first node's `val` attribute. Returns None
        when no first node is set.

        NOTE(review): `Node` in this module defines no `val` attribute, so
        this only works for node classes that provide one -- TODO confirm
        whether this method is still used.
        """
        if not self.firstNode():
            return None
        return self.validate(self.first_node.val, val)
#=============================================================================== # Network View Classes #===============================================================================
class AbstractNetworkView:
    """
    A base class for views on `Graph` models.

    Call `setModel` to replace the model object. The model's signals are
    connected automatically to the synchronization slots listed by
    `getSignalsAndSlots`.

    The abstract view has no built-in support for pushing changes back into
    the model (e.g. deleting nodes, changing selection). Subclasses should
    implement such operations by calling the model directly; the model's
    signals then propagate the change forward to every view.

    self.nodes is a dictionary mapping model node objects to view node
    objects. self.edges is a dictionary mapping pairs of model node objects
    to view edge objects.

    Method names that mention nodes and edges refer to view objects. For
    example, makeNode() makes a view node and addEdge() adds an edge view
    object to the view.

    :cvar MODEL_CLASS: the class instantiated as the default model when
        `setModel` is passed `None`
    :vartype MODEL_CLASS: `Graph` or subclass of `Graph`

    :ivar _sync_with_model: whether to automatically synchronize this view
        (and its subviews) with the model
    :vartype _sync_with_model: bool
    """

    MODEL_CLASS = Graph
[docs] def __init__(self): self.model = None self.nodes = {} self.edges = {} self.skip_selectionChanged = False self._subviews = set() self._sync_with_model = True
#=========================================================================== # Model-View Connections #===========================================================================
[docs] def syncAll(self): """ Synchronize the full model and selection state. """ model = self.model self.syncModel() selection = model.selected_nodes.union(model.selected_edges) self.syncSelection(selection, model)
[docs] def syncRecursive(self): """ Synchronize the full model and selection state on this view and all subviews. """ self.syncAll() for subview in self._subviews: subview.syncRecursive()
[docs] def setModelSyncEnabled(self, enable): """ Enable or disable automatic synchronization with the model for this view and all subviews. """ for subview in self._subviews: subview.setModelSyncEnabled(enable) if self._sync_with_model == enable: return self._sync_with_model = enable if enable: self._connectSignals() self.syncAll() else: self._disconnectSignals()
[docs] def setModel(self, model): """ Set the model for this view and synchronize to it. Any subviews will have the model set on them as well. :param model: the graph model :type model: Graph """ if model is None: model = self.MODEL_CLASS() if self._sync_with_model: self._disconnectSignals() self.model = model for subview in self._subviews: subview.setModel(model) if self._sync_with_model: self._connectSignals() self.syncAll()
def _connectSignals(self): """ If a model is defined, connect all signal/slot pairs. """ if self.model: for signal, slot in self.getSignalsAndSlots(self.model): signal.connect(slot) def _disconnectSignals(self): """ If a model is defined, disconnect all signal/slot pairs. """ if self.model: for signal, slot in self.getSignalsAndSlots(self.model): signal.disconnect(slot)
[docs] def getSignalsAndSlots(self, model): """ Get a list of signal/slot pairs for a model. This list will be used when setting a new model to disconnect the old model signals from their slots and connect the new model's signals to those slots. Override this method to modify or extend signals/slots in derived classes. :param model: the graph model :type model: Graph """ signals = model.signals ss_list = [ (signals.graphChanged, self.syncModel), (signals.nodesAdded, self.syncNodesAdded), (signals.nodesDeleted, self.syncNodesDeleted), (signals.nodesChanged, self.syncNodesChanged), (signals.edgesChanged, self.syncModel), (signals.selectionChanged, self.syncSelection), ] return ss_list
[docs] def addSubview(self, subview): """ Add a subview to this view. A subview is another AbstractNetworkView that should always have the same model as its parent view (this view). Adding will automatically set its model to the current model. Changing the model on this view will result in all its subviews getting the new model set :param subview: the new subview to add to this view :type subview: AbstractNetworkView """ self._subviews.add(subview) subview.setModel(self.model)
[docs] def removeSubview(self, subview): """ Removes the specified subview. The subview is not deleted or altered, and the model remains set. :param subview: :type subview: """ self._subviews.remove(subview)
#=========================================================================== # Model-View Synchronization #===========================================================================
[docs] def syncModel(self): self.syncNodes() self.syncEdges()
[docs] def syncNodes(self): graph = self.model nodeset = graph.getNodes() delnodes = set(self.nodes).difference(nodeset) self.syncNodesAdded(nodeset) self.syncNodesChanged(nodeset) self.syncNodesDeleted(delnodes)
[docs] def syncNodesDeleted(self, nodes): self._removeNodes(nodes) self.syncEdges()
[docs] def syncNodesAdded(self, nodes): new_nodes = nodes.difference(set(self.nodes)) self._addNodes(new_nodes) self.syncEdges()
[docs] def syncNodesChanged(self, nodes): self.updateNodes(nodes) if self.edges: edges = self.model.getEdges(nodes) self.updateEdges(edges)
[docs] def syncEdges(self): model_edges = set(self.model.getEdges()) known_edges = set(self.edges) del_edges = known_edges.difference(model_edges) self._removeEdges(del_edges) add_edges = model_edges.difference(known_edges) self._addEdges(add_edges) up_edges = model_edges.intersection(known_edges) self.updateEdges(up_edges)
[docs] def syncSelection(self, selection, source): if source == self: return selected_view_objects = [] for model_obj in selection: if isinstance(model_obj, Node): viewnode = self.nodes.get(model_obj) if viewnode: selected_view_objects.append(viewnode) elif isinstance(model_obj, Edge): viewedge = self.getEdge(model_obj) if viewedge: selected_view_objects.append(viewedge) self.skip_selectionChanged = True self.selectItems(selected_view_objects) self.skip_selectionChanged = False
#=========================================================================== # Node operations #=========================================================================== def _addNodes(self, nodes): node_map = self.makeNodes(nodes) self.nodes.update(node_map) self.addNodes(set(node_map.values())) def _removeNodes(self, nodes): viewnodes = [self.getNode(node) for node in nodes] self.removeNodes(viewnodes) for node in nodes: self.nodes.pop(node)
[docs] def makeNodes(self, nodes): """ Create new view nodes and return a dictionary mapping supplied model nodes to corresponding view nodes. Do not add new view nodes to the view. By default this method returns an "identity dictionary" that maps nodes to themselves. Subclasses should override this method to implement their own view nodes. :param nodes: model nodes :type nodes: list(Node) :return: a dictionary mapping supplied nodes to view nodes :rtype: dict(Node, object) """ return {node: node for node in nodes}
[docs] def makeNode(self, node): """ Convenience method for calling `makeNodes()` with a single node. Rather than returning a dictionary mapping nodes to view nodes, returns the view node corresponding to the supplied node. :param node: the model node :type node: Node :return: the view node :rtype: object """ node_map = self.makeNodes([node]) return node_map.get(node)
[docs] def addNode(self, viewnode): """ A convenience function for calling `addNodes()` for a single node. :param viewnode: a view node :type viewnode: object """ self.addNodes([viewnode])
[docs] def removeNode(self, viewnode): """ Convenience method for calling `removeNode()` for a single node. :param viewnode: a view node :type viewnode: object """ self.removeNodes([viewnode])
[docs] def updateNode(self, node): """ Convenience method for calling `updateNodes()` for a single node. :param node: the model node to update to :type node: Node """ self.updateNodes([node])
[docs] def getModelNodes(self, node_keys=None): """ Retrieve a set of model nodes optionally indicated by a list of keys. If none is provided, return all nodes. :param node_keys: optionally, a list of nodes, gnodes, or strings that correspond to the desired model nodes :type node_keys: list(object) or NoneType :return: a set of nodes :rtype: set(Node) """ nodes_in_view = set(list(self.nodes)) nodes_in_model = self.model.getNodes(node_keys) return nodes_in_view.intersection(nodes_in_model)
[docs] def getNode(self, node): """ :param node: a model node :type node: Node :return: corresponding view node, if available :rtype: `object` or `None` """ return self.nodes.get(node, None)
#=========================================================================== # Edge operations #=========================================================================== def _addEdges(self, edges): for edge in edges: view_edge = self.getEdge(edge) if view_edge is not None: msg = f'A view edge already exists for {edge}.' raise ValueError(msg) edge_map = self.makeEdges(edges) for edge, view_edge in edge_map.items(): self.edges[edge] = view_edge self.addEdges(list(edge_map.values())) def _removeEdges(self, edges): view_edges = [self.getEdge(edge) for edge in edges] self.removeEdges(view_edges) for edge in edges: self.edges.pop(edge)
[docs] def makeEdges(self, edges): """ Given a list of model edges, return a dictionary mapping them to corresponding view edges. Does not add view edges to the view. By default this method returns an identity dictionary, mapping model edges to themselves. Subclasses should override this method if they want to implement their own view edges. :param edges: a list model nodes :type nodepairs: list(Edge) :return: a dictionary mapping model edges to view edges :rtype: dict(Edge, object) """ return {edge: edge for edge in edges}
[docs] def makeEdge(self, edge): """ Convenience method for calling `makeEdges()` for a single edge. Rather than return a dictionary mapping model edges to view edges, returns a singe view edge. Does not add a view edge to the view. :param edge: a model edge :type edge: Edge :return: a view edge :rtype: object """ edge_map = self.makeEdges([edge]) return edge_map.get(edge)
[docs] def addEdge(self, viewedge): """ Convenience method for calling `addEdges()` for a single edge. :param viewedge: the view edge to add to the view :type viewedge: object """ self.addEdges([viewedge])
[docs] def removeEdge(self, viewedge): """ Convenience method for calling `removeEdges()` for a single edge. :param viewedge: the view edge to remove from the view :type viewedge: object """ self.removeEdges([viewedge])
[docs] def updateEdge(self, edge): """ A convenience method for calling `updateEdges()` for a single edge. :param edge: the model edge corresponding to the view edge to update :type edge: Edge """ self.updateEdges([edge])
[docs] def getModelEdges(self, nodes=None): """ Return all model edges connected to a model node or set of model nodes. If no node is specified, all the edges in the graph are returned. This method acts like `Graph.getEdges()`, but it filters for model edges that are available in this view. :param nodes: optionally a node or list of nodes :type nodes: `list(Node)`, `Node`, or `None` :return: a list of model edges :rtype: list(Edge) """ model_edges = set(self.model.getEdges(nodes=nodes)) return list(model_edges.intersection(set(self.edges)))
[docs] def getEdge(self, edge): """ Return the view edge corresponding to the supplied model edge. :param edge: a model edge :type edge: Edge :return: the corresponding view edge if available :rtype: object or None """ return self.edges.get(edge)
[docs] def getEdges(self, nodes=None): """ Return a list of view edges, filtering the list so that the edges are connected to the optionally-supplied node or iterable of nodes. :param nodes: a node or iterable of nodes :type nodes: iterable[Node] or Node or NoneType :return: list of view edges :rtype: list[NetworkEdge or NoneType] """ return [self.getEdge(edge) for edge in self.model.getEdges(nodes)]
#=========================================================================== # Pure virtual methods #===========================================================================
[docs] def addNodes(self, viewnodes): """ Takes view nodes and adds them to the view if that makes sense (eg. add graphics items to scene, add rows to table, etc.) It should not add the view node to `self.nodes`; that is handled in `_addNodes()`. :param viewnodes: view nodes to add to the view :type viewnodes: list(object) """
[docs] def removeNodes(self, viewnodes): """ Removes view nodes from the view if that makes sense (eg. remove graphics items from scene, remove table rows, etc.) It should not remove view nodes from `self.nodes`; that is handled in `_removeNodes()`. :param viewnodes: a list of view nodes :type viewnodes: list(object) """
[docs] def updateNodes(self, nodes): """ Performs any operations necessary to update the view to the current model state. Note that this method takes model nodes, not view nodes. :param nodes: model nodes which must have their views updated :type nodes: list(Node) """
[docs] def addEdges(self, viewedges): """ Adds view edges to the view. Does not add view edges to `self.edges`. :param viewedges: view edges to add to the view :type viewedges: list(object) """
[docs] def removeEdges(self, viewedges): """ Removes view edges from the view. Does not remove view edges from `self.edges`. :param viewedges: view edges to remove from the view :type viewedges: list(object) """
[docs] def updateEdges(self, edges): """ Performs any operations necessary to update the view to the current model state. :param edges: a list of model edges corresponding to view edges that should be updated :type edges: list(Edge) """
[docs] def selectItems(self, selected_view_objects): """ Selects view objects in the view. Currently only view nodes will be requested, but may be expanded to allow a combination of nodes and edges to be selected. :param selected_view_objects: a list of view objects to be selected :type selected_view_objects: list(object) """
#=============================================================================== # Layout calculations #=============================================================================== # # line segment intersection using vectors # see Computer Graphics by F.S. Hill #
def perp(a):
    """
    Return the vector perpendicular to the 2D vector `a`, i.e.
    (x, y) -> (-y, x), a 90 degree counter-clockwise rotation in a y-up
    coordinate system.

    :param a: a 2-element vector
    :type a: numpy.array

    :return: a new array with the same shape and dtype as `a`
    :rtype: numpy.array
    """
    b = np.empty_like(a)
    b[0] = -a[1]
    b[1] = a[0]
    return b


# Tolerance for seg_intersect(): segments that merely touch at an endpoint
# (e.g. two edges sharing a node) are not counted as crossing, and small
# floating point errors are absorbed.
_eps = 1e-8


def seg_intersect(a1, a2, b1, b2):
    """
    Checks whether two line segments cross each other.

    Line segment a is given by endpoints a1, a2 and line segment b by
    endpoints b1, b2. The segments must cross strictly inside both of them.
    Segments that only touch at an endpoint, and parallel, collinear or
    zero-length segments, are reported as not intersecting.

    :param a1: first endpoint of line segment a
    :type a1: numpy.array

    :param a2: second endpoint of line segment a
    :type a2: numpy.array

    :param b1: first endpoint of line segment b
    :type b1: numpy.array

    :param b2: second endpoint of line segment b
    :type b2: numpy.array

    :return: whether the line segments intersect
    :rtype: bool
    """
    da = a2 - a1
    db = b2 - b1
    dp = a1 - b1
    denom = np.dot(perp(da), db)
    if denom == 0:
        # Line segments are parallel (or one of them has zero length)
        return False
    # The supporting lines meet where a1 + t * da == b1 + u * db. Dotting
    # with perp(da) and perp(db) gives the two parameters below. Testing the
    # parameters (instead of only the x-coordinate of the intersection) also
    # works for vertical segments, whose x-extent is zero.
    u = np.dot(perp(da), dp) / denom
    t = np.dot(perp(db), dp) / denom
    # The epsilon excludes endpoint contacts and absorbs rounding error.
    return bool(_eps < t < 1 - _eps and _eps < u < 1 - _eps)
def _has_fewer_crossings(edges, node_coords, goal):
    """
    Determine whether drawing the graph with the given node coordinates
    produces fewer than `goal` edge crossings.

    Counting stops as soon as `goal` crossings have been found, because the
    answer is then known to be False. In that case the returned count is
    `goal` rather than the full total.

    :param edges: pairs of node keys, one pair per edge
    :type edges: iterable(tuple)

    :param node_coords: maps each node key to its (x, y) coordinates
    :type node_coords: dict

    :param goal: the number of crossings to beat
    :type goal: int

    :return: whether there are fewer than `goal` crossings, and the number of
        crossings counted (at most `goal`)
    :rtype: tuple(bool, int)
    """
    import itertools

    crossings = 0
    if goal <= crossings:
        # No crossing count can be below a non-positive goal
        return False, crossings

    segments = []
    for n1, n2 in edges:
        x1, y1 = node_coords[n1]
        x2, y2 = node_coords[n2]
        segments.append((np.array([x1, y1]), np.array([x2, y2])))

    for (a1, a2), (b1, b2) in itertools.combinations(segments, 2):
        if seg_intersect(a1, a2, b1, b2):
            crossings += 1
            if crossings >= goal:
                # Goal reached; the remaining pairs cannot change the answer
                return False, crossings
    return True, crossings


#===============================================================================
# Code copied and modified from networkx.drawing.layout
#===============================================================================
def fruchterman_reingold_layout(G,
                                dim=2,
                                pos=None,
                                fixed=None,
                                iterations=50,
                                weight='weight',
                                scale=1):
    """
    Position nodes using Fruchterman-Reingold force-directed algorithm.

    :param G: NetworkX graph
    :param dim: Dimension of layout
    :type dim: int

    :param pos: Initial positions for nodes as a dictionary with node as keys
        and values as a list or tuple. Nodes missing from `pos` (or all nodes,
        if `pos` is None) get random initial positions.
    :type pos: dict

    :param fixed: Nodes to keep fixed at initial position. optional
    :type fixed: list

    :param iterations: Number of iterations of spring-force relaxation
    :type iterations: int

    :param weight: The edge attribute that holds the numerical value used for
        the edge weight. If None, then all edge weights are 1.
    :type weight: str or None

    :param scale: Scale factor for positions. Positions are only rescaled
        when `fixed` is None.
    :type scale: float

    :rtype: dict
    :returns: A dictionary of positions keyed by gnode

    Example::

        G = nx.path_graph(4)
        pos = spring_layout(G)
        # The same using longer function name
        pos = fruchterman_reingold_layout(G)
    """
    if fixed is not None:
        gnode_idx_map = {gnode: idx for idx, gnode in enumerate(G)}
        # dtype=int keeps this a valid index array even if `fixed` is empty
        # (an empty list would otherwise become a float array).
        fixed = np.asarray([gnode_idx_map[v] for v in fixed], dtype=int)
    if pos is not None:
        pos_arr = np.asarray(np.random.random((len(G), dim)))
        for i, n in enumerate(G):
            if n in pos:
                pos_arr[i] = np.asarray(pos[n])
    else:
        pos_arr = None

    if len(G) == 0:
        return {}
    if len(G) == 1:
        return {next(iter(G)): (1,) * dim}

    # to_numpy_matrix was removed in NetworkX 3.0; use to_numpy_array when it
    # exists and fall back to the old function for older NetworkX releases.
    to_numpy = getattr(nx, 'to_numpy_array', None) or nx.to_numpy_matrix
    A = to_numpy(G, weight=weight)
    pos = _fruchterman_reingold(A, dim, pos_arr, fixed, iterations)
    if fixed is None:
        pos = _rescale_layout(pos, scale=scale)
    return dict(zip(G, pos))


spring_layout = fruchterman_reingold_layout


def _fruchterman_reingold(A, dim=2, pos=None, fixed=None, iterations=50):
    """
    Position nodes in adjacency matrix `A` using Fruchterman-Reingold.

    The entry point for NetworkX graphs is `fruchterman_reingold_layout()`.

    :param A: square adjacency (weight) matrix
    :param dim: dimension of the layout
    :type dim: int
    :param pos: initial positions, one row per node; random if None
    :type pos: numpy.ndarray or None
    :param fixed: indices of nodes whose positions must not change
    :type fixed: numpy.ndarray or None
    :param iterations: number of relaxation iterations
    :type iterations: int

    :return: node positions, one row per node
    :rtype: numpy.ndarray

    :raise networkx.NetworkXError: if `A` has no `shape` attribute
    """
    try:
        nnodes, _ = A.shape
    except AttributeError:
        raise nx.NetworkXError(
            "fruchterman_reingold() takes an adjacency matrix as input")

    A = np.asarray(A)  # make sure we have an array instead of a matrix

    if pos is None:
        # random initial positions
        pos = np.asarray(np.random.random((nnodes, dim)), dtype=A.dtype)
    else:
        # make sure positions are of same type as matrix
        pos = pos.astype(A.dtype)

    # optimal distance between nodes
    k = np.sqrt(1.0 / nnodes)
    # The initial "temperature" is the largest step allowed in the dynamics.
    t = STARTING_TEMPERATURE
    # simple cooling scheme.
    # linearly step down by dt on each iteration so last iteration is size dt.
    dt = t / (iterations + 1)
    delta = np.zeros((pos.shape[0], pos.shape[0], pos.shape[1]),
                     dtype=A.dtype)
    # the inscrutable (but fast) version
    # this is still O(V^2)
    # could use multilevel methods to speed this up significantly
    for _iteration in range(iterations):
        # matrix of difference between points
        for i in range(pos.shape[1]):
            delta[:, :, i] = pos[:, i, None] - pos[:, i]
        # distance between points
        distance = np.sqrt((delta**2).sum(axis=-1))
        # enforce minimum distance of 0.01
        distance = np.where(distance < 0.01, 0.01, distance)
        # displacement "force": PUSH/PUSHEXP-controlled repulsion minus
        # attraction proportional to edge weight
        displacement = np.transpose(
            np.transpose(delta) *
            (PUSH * k**PUSHEXP / distance**PUSHEXP - A * distance / k)).sum(
                axis=1)
        # update positions
        length = np.sqrt((displacement**2).sum(axis=1))
        length = np.where(length < 0.01, 0.01, length)
        delta_pos = np.transpose(np.transpose(displacement) * t / length)
        if fixed is not None:
            # don't change positions of fixed nodes
            delta_pos[fixed] = 0.0
        pos += delta_pos
        # cool temperature
        t -= dt
    return pos


def _rescale_layout(pos, scale=1):
    """
    Shift positions so that every axis starts at 0, then scale all axes by
    the same factor so the largest coordinate equals `scale` (this preserves
    the aspect ratio). `pos` is modified in place and returned.

    If all positions coincide they are left at the origin instead of being
    divided by zero.

    :param pos: positions, one row per node
    :type pos: numpy.ndarray
    :param scale: the largest coordinate after rescaling
    :type scale: float

    :return: the rescaled `pos`
    :rtype: numpy.ndarray
    """
    lim = 0  # max coordinate for all axes
    for i in range(pos.shape[1]):
        pos[:, i] -= pos[:, i].min()
        lim = max(pos[:, i].max(), lim)
    if lim > 0:
        for i in range(pos.shape[1]):
            pos[:, i] *= scale / lim
    return pos