Source code for schrodinger.application.desmond.mplchart

"""
Tools for using matplotlib charts.

Copyright Schrodinger, LLC. All rights reserved.

"""

# Contributors: Dave Giesen

import atexit
import os
import shutil
import tempfile
from past.utils import old_div

import matplotlib
import matplotlib.font_manager as font_manager
import numpy
from matplotlib import cm as mpl_cmap
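# NOTE: create_surface_on_figure uses the '3d' projection. matplotlib 3.2+
# registers it automatically; older matplotlib versions additionally need
# ``from mpl_toolkits.mplot3d import Axes3D`` (even though Axes3D is not used
# directly) before '3d' is recognized as a valid projection.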

from schrodinger.utils import fileutils
from schrodinger.utils import qapplication


def remove_temp_dir(path):
    """
    Delete the temporary directory created by this module.

    Errors raised while deleting (for example because the directory is
    already gone) are ignored.

    :type path: str
    :param path: the directory to delete, including all of its contents
    """
    shutil.rmtree(path, ignore_errors=True)
def _get_home():
    """
    Find the user's home directory if possible.

    The directory reported by ``os.path.expanduser("~")`` is tried first. If
    that is not an existing directory, the HOME, USERPROFILE and TMP
    environment variables are tried in that order, stopping at the first one
    that names an existing directory.

    No exception is raised when nothing suitable is found. In that case the
    returned value may not be an existing directory: it is the last candidate
    tried, or an empty string.

    :rtype: str
    :return: the best guess at the user's home directory
    """
    path = ''
    try:
        path = os.path.expanduser("~")
    except Exception:
        # Best effort only - fall back to the environment variables below
        pass
    if not os.path.isdir(path):
        for evar in ('HOME', 'USERPROFILE', 'TMP'):
            value = os.environ.get(evar)
            if value is None:
                continue
            path = value
            if os.path.isdir(path):
                break
    return path


def _is_writable_dir(path):
    """
    Check whether a file can be created and written in a directory.

    :type path: str
    :param path: the directory to test

    :rtype: bool
    :return: True if `path` is non-empty and a temporary file could be
        created and written in it, False otherwise
    """
    if not path:
        return False
    try:
        with tempfile.TemporaryFile(mode='w', dir=path) as tfile:
            tfile.write('1')
    except OSError:
        return False
    return True


def _get_configdir():
    """
    Find the ``.matplotlib`` directory in the user's home directory and
    check whether it is usable, creating it if necessary.

    :rtype: str or bool
    :return: the path of the configuration directory if it exists and is
        writable (or was just created), False otherwise
    """
    home = _get_home()
    cpath = os.path.join(home, '.matplotlib')
    if os.path.exists(cpath):
        return cpath if _is_writable_dir(cpath) else False
    if not _is_writable_dir(home):
        return False
    try:
        os.mkdir(cpath)
    except FileExistsError:
        # DESMOND-7025: another process created the directory between the
        # exists() check above and this mkdir() call
        pass
    except OSError:
        # The directory could not be created, so it is not usable
        return False
    return cpath


# This module can be run in situations where the normal matplotlib directory is
# not writable. We'll use the exact method matplotlib uses to determine this to
# see if we need to redirect to a temp directory.
if not _get_configdir():
    # The usual matplotlib config directory is unusable, so point matplotlib
    # at a '.matplotlib' directory inside a writable temp location instead.
    error, tempdir = fileutils.get_directory(fileutils.TEMP)
    if error or not tempdir or not _is_writable_dir(tempdir):
        tempdir = tempfile.gettempdir()
    newconfigdir = os.path.join(tempdir, '.matplotlib')
    try:
        os.mkdir(newconfigdir)
    except FileExistsError:
        # Created earlier, or concurrently by another process
        pass
    # NOTE(review): this directory may live in a shared temp location and is
    # removed at exit even if another process is still using it - confirm.
    atexit.register(remove_temp_dir, newconfigdir)
    # NOTE(review): matplotlib and matplotlib.font_manager are imported at the
    # top of this module, before this assignment, so matplotlib may already
    # have resolved its config/cache directories - confirm the redirect works.
    os.environ['MPLCONFIGDIR'] = newconfigdir

# Color names used, in order, as the default series colors
COLOR_NAME = [
    "black",
    "red",
    "green",
    "blue",
    "purple",
    "yellow",
    "orange",
    "violet",
    "skyblue",
    "gold",
    "grey",
]

# Maps descriptive marker names onto matplotlib marker symbols
MARKER_TRANSLATION = {
    "cross": "+",
    "rectangle": "s",
    "diamond": "d",
    "circle": "o",
    "square": "s",
    "x": "x",
    "arrow": "^"
}

DEFAULT_FONTSIZE = 'x-small'

# matplotlib's default font size. Read it from rcParams directly:
# constructing a new FontManager just to ask for it rescans all system fonts.
default_size = matplotlib.rcParams['font.size']

# Point sizes for the named font sizes, relative to default_size
FONT_SCALINGS = {
    'xx-small': 0.579 * default_size,
    'x-small': 0.694 * default_size,
    'small': 0.833 * default_size,
    'medium': 1.0 * default_size,
    'large': 1.200 * default_size,
    'x-large': 1.440 * default_size,
    'xx-large': 1.728 * default_size
}
def _x_labels_overlap(labels):
    """
    Check whether any non-empty tick label crowds the label before it.

    :type labels: list
    :param labels: the x-axis major tick labels, in axis order

    :rtype: bool
    :return: True if the left edge of a non-empty label is less than 1
        display unit to the right of the previous label's right edge
    """
    for previous, current in zip(labels, labels[1:]):
        old_label_pos = previous.get_window_extent().get_points()
        new_label_pos = current.get_window_extent().get_points()
        # For unknown reasons some labels have zero width. Exclude such
        # labels from the comparison.
        if (old_label_pos[0][0] == old_label_pos[1][0] or
                new_label_pos[0][0] == new_label_pos[1][0]):
            continue
        old_right_x = old_label_pos[1][0]
        new_left_x = new_label_pos[0][0]
        if current.get_text() and new_left_x - old_right_x < 1:
            return True
    return False


def prevent_overlapping_x_labels(canvas, axes_number=0):
    """
    Given a canvas that contains a figure that contains at least one axes
    instance, checks the x-axis tick labels to make sure they don't overlap.
    If they do, the number of ticks is reduced until there is no overlap.

    :type canvas: matplotlib canvas object
    :param canvas: the canvas that contains the figure/axes objects

    :type axes_number: int
    :param axes_number: the index of the axes on the figure to examine.
        Default is 0, which is the first set of axis added to the figure.
    """
    # The canvas is deliberately not drawn before the first check, due to
    # https://github.com/matplotlib/matplotlib/issues/10874
    xaxis = canvas.figure.get_axes()[axes_number].get_xaxis()
    labels = list(xaxis.get_majorticklabels())
    # matplotlib will insist on keeping a minimum number of labels (some of
    # which may be blank), so to be safe we cap the number of times we try
    # to reduce the number of labels.
    max_tries = len(labels) - 1
    tries = 0
    overlap = True
    while overlap and tries < max_tries:
        tries += 1
        overlap = _x_labels_overlap(labels)
        if overlap:
            # Reduce the number of ticks by one to make room for the labels
            locator = xaxis.get_major_locator()
            # Number of bins = number of ticks - 1, so to get one fewer tick
            # we need two fewer bins than the current number of ticks. Never
            # request fewer than one bin; zero bins is not a usable value.
            locator.set_params(nbins=max(1, len(labels) - 2))
            # Must reset the ticks - matplotlib leaves an extra tick at the
            # end when it figures out new ticks. This at least makes sure
            # the last tick is blank.
            xaxis.reset_ticks()
            canvas.draw()
            labels = list(xaxis.get_majorticklabels())
def get_xy_plot_widget(xvals, *ylists, **kw):
    """
    Create a scatter or line chart. The line chart may optionally have error
    bars associated with it. Multiple series can be plotted by passing in
    more than one list of y values, i.e.
    get_xy_plot_widget(x, y1, y2, chart_type='scatter')

    The plot is returned in a QFrame widget.

    :type xvals: list
    :param xvals: the x values to plot

    :type ylists: one or more lists
    :keyword ylists: Each y series to plot should be given as an argument to
        the function, and each should be the same length as x

    :type err_y: list of lists
    :keyword err_y: the i'th item of err_y is a list of numerical error bars,
        one for each item in the i'th y list

    :type chart_type: str
    :keyword chart_type: type of chart to produce
        - scatter: scatterplot
        - line: line (default)

    :type marker: tuple
    :keyword marker: tuple of (symbol, color, size), only used for scatter
        plots:
        - symbol (1-character str)
          - s - square ('square', rectangle accepted)
          - o - circle ('circle' accepted)
          - ^ - triangle up ('arrow' accepted)
          - > - triangle right
          - < - triangle left
          - v - triangle down
          - d - diamond ('diamond' accepted)
          - p - pentagon
          - h - hexagon
          - 8 - octagon
          - + - plus ('cross' accepted)
          - x - x
        - color (str): black, red, green, blue, purple, yellow, orange,
          violet, skyblue, gold, grey
        - size (int)

    :type size: tuple
    :keyword size: (x, y) plot size, in pixels at the given dpi

    :type x_label: str
    :keyword x_label: X-axis label

    :type y_label: str
    :keyword y_label: Y-axis label

    :type x_range: tuple
    :keyword x_range: (min, max) values for the X-axis

    :type y_range: tuple
    :keyword y_range: (min, max) values for the Y-axis

    :type color: list
    :keyword color: list of color names to cycle through. See marker:color
        for some color names.

    :type bg: str
    :keyword bg: color name for the plot background. See marker:color for
        some color names.

    :type legend: list
    :keyword legend: list of strings, each item is the name of a y series in
        the legend

    :type title: str
    :keyword title: the title of the plot

    :type dpi: int
    :keyword dpi: dots per inch for the plot

    :type fontsize: int or str
    :keyword fontsize: size in points, or one of the following -
        xx-small, x-small, small, medium, large, x-large, xx-large

    :rtype: QFrame
    :return: The QFrame widget that contains the plot
    """
    # QtWidgets was previously only imported under __main__, which made this
    # function fail with NameError when the module was imported.
    from schrodinger.Qt import QtWidgets
    from schrodinger.ui.qt import smatplotlib
    from schrodinger.ui.qt import swidgets

    # Check for a PyQt application instance and create one if needed. Keep a
    # reference while the widgets below are created.
    app = qapplication.get_application()

    # Grab some of the keywords, the rest will be passed on
    dpi = int(kw.get('dpi', 100))
    plotsize = kw.get('size', (300, 200))
    # True division: the figure size is in (possibly fractional) inches.
    # Floor division would shrink the plot, or give a zero size when the
    # requested size is smaller than dpi.
    width = plotsize[0] / dpi
    height = plotsize[1] / dpi

    # Create the plot and return it
    frame = QtWidgets.QFrame()
    layout = swidgets.SVBoxLayout(frame)
    canvas = smatplotlib.SmatplotlibCanvas(width=width,
                                           height=height,
                                           dpi=dpi,
                                           layout=layout)
    create_mpl_plot_on_figure(canvas.figure, xvals, *ylists, **kw)
    # Make sure the tick labels don't overlap
    prevent_overlapping_x_labels(canvas)
    return frame
def get_xy_plot(xvals, *ylists, **kw):
    """
    Create a scatter or line chart. The line chart may optionally have error
    bars associated with it. Multiple series can be plotted by passing in
    more than one list of y values, i.e.
    get_xy_plot(x, y1, y2, chart_type='scatter')

    The plot is saved as an image file. The filename keyword gives the path
    of the file to write; that file is written and the value of filename is
    returned.

    :type xvals: list
    :param xvals: the x values to plot

    :type ylists: one or more lists
    :keyword ylists: Each y series to plot should be given as an argument to
        the function, and each should be the same length as x

    :type err_y: list of lists
    :keyword err_y: the i'th item of err_y is a list of numerical error bars,
        one for each item in the i'th y list

    :type chart_type: str
    :keyword chart_type: type of chart to produce
        - scatter: scatterplot
        - line: line (default)

    :type marker: tuple
    :keyword marker: tuple of (symbol, color, size), only used for scatter
        plots
        - symbol (1-character str)
          - s - square ('square', rectangle accepted)
          - o - circle ('circle' accepted)
          - ^ - triangle up ('arrow' accepted)
          - > - triangle right
          - < - triangle left
          - v - triangle down
          - d - diamond ('diamond' accepted)
          - p - pentagon
          - h - hexagon
          - 8 - octagon
          - + - plus ('cross' accepted)
          - x - x
        - color (str): black, red, green, blue, purple, yellow, orange,
          violet, skyblue, gold, grey
        - size (int)

    :type size: tuple
    :keyword size: (x, y) plot size, in pixels at the given dpi

    :type x_label: str
    :keyword x_label: X-axis label

    :type y_label: str
    :keyword y_label: Y-axis label

    :type x_range: tuple
    :keyword x_range: (min, max) values for the X-axis

    :type y_range: tuple
    :keyword y_range: (min, max) values for the Y-axis

    :type color: list
    :keyword color: list of color names to cycle through. See marker:color
        for some color names.

    :type bg: str
    :keyword bg: color name for the plot background. See marker:color for
        some color names.

    :type legend: list
    :keyword legend: list of strings, each item is the name of a y series in
        the legend

    :type title: str
    :keyword title: the title of the plot

    :type dpi: int
    :keyword dpi: dots per inch for the plot

    :type fontsize: int or str
    :keyword fontsize: size in points, or one of the following -
        xx-small, x-small, small, medium, large, x-large, xx-large

    :type filename: str
    :keyword filename: The pathway to a file that the image of this plot
        should be saved in.

    :type format: str
    :param format: the image format to save the chart in. Must be a
        matplotlib-recognized format argument to Figure.savefig(format=nnn).
        Default is nnn='png'

    :rtype: filename
    :return: The filename the image of the plot was saved into (same string
        as passed in with the filename keyword).
    """
    from matplotlib.backends.backend_agg import FigureCanvasAgg
    from matplotlib.figure import Figure

    # Grab some of the keywords, the rest will be passed on
    dpi = int(kw.get('dpi', 100))
    plotsize = kw.get('size', (300, 200))
    # True division: the figure size is in (possibly fractional) inches.
    # Floor division would shrink the image, or give a zero size when the
    # requested size is smaller than dpi.
    width = plotsize[0] / dpi
    height = plotsize[1] / dpi

    # Create the basic plot area
    figure = Figure(figsize=(width, height), dpi=dpi)
    canvas = FigureCanvasAgg(figure)
    create_mpl_plot_on_figure(figure, xvals, *ylists, **kw)

    # Make sure the tick labels don't overlap
    prevent_overlapping_x_labels(canvas)

    # Save the plot to an image file
    filename = kw.get('filename')
    image_format = kw.get('format', 'png')
    figure.savefig(filename,
                   dpi=dpi,
                   orientation='landscape',
                   format=image_format)
    return filename
def create_mpl_plot_on_figure(figure, xvals, *ylists, **kw):
    """
    Create a scatter or line chart on an existing matplotlib Figure. The
    line chart may optionally have error bars associated with it. Multiple
    series can be plotted by passing in more than one list of y values, i.e.
    create_mpl_plot_on_figure(figure, x, y1, y2, chart_type='scatter')

    :type figure: matplotlib Figure object
    :param figure: the Figure object the plot should be created on.

    :type xvals: list
    :param xvals: the x values to plot

    :type ylists: one or more lists
    :keyword ylists: Each y series to plot should be given as an argument to
        the function, and each should be the same length as x

    :type err_y: list of lists
    :keyword err_y: the i'th item of err_y is a list of numerical error bars,
        one for each item in the i'th y list. Error bars are drawn only for
        line charts, and only for points whose error is nonzero. The list
        passed in is not modified.

    :type chart_type: str
    :keyword chart_type: type of chart to produce
        - scatter: scatterplot
        - line: line (default)

    :type marker: tuple
    :keyword marker: tuple of (symbol, color, size), only used for scatter
        plots. Only the symbol is currently used: series colors come from
        the color keyword and the marker size is fixed.
        - symbol (1-character str):
          - s - square ('square', rectangle accepted)
          - o - circle ('circle' accepted)
          - ^ - triangle up ('arrow' accepted)
          - > - triangle right
          - < - triangle left
          - v - triangle down
          - d - diamond ('diamond' accepted)
          - p - pentagon
          - h - hexagon
          - 8 - octagon
          - + - plus ('cross' accepted)
          - x - x
        - color (str): black, red, green, blue, purple, yellow, orange,
          violet, skyblue, gold, grey
        - size (int)

    :type x_label: str
    :keyword x_label: X-axis label

    :type y_label: str
    :keyword y_label: Y-axis label

    :type x_range: tuple
    :keyword x_range: (min, max) values for the X-axis

    :type y_range: tuple
    :keyword y_range: (min, max) values for the Y-axis. If not given, the
        range covers all y values (and all error bars, when err_y is given).

    :type color: list
    :keyword color: list of color names to cycle through. See marker:color
        for some color names.

    :type bg: str
    :keyword bg: color name for the plot background. See marker:color for
        some color names.

    :type legend: list
    :keyword legend: list of strings, each item is the name of a y series in
        the legend

    :type title: str
    :keyword title: the title of the plot

    :type fontsize: int or str
    :keyword fontsize: size in points, or one of the following -
        xx-small, x-small, small, medium, large, x-large, xx-large
    """
    title = kw.get('title')
    bg = kw.get('bg', 'white')
    legend = kw.get('legend', [])
    x_label = kw.get('x_label')
    y_label = kw.get('y_label')
    colors = kw.get('color', COLOR_NAME[:])
    fontsize = kw.get('fontsize', DEFAULT_FONTSIZE)
    chart_type = kw.get('chart_type', 'line')

    figure.set_facecolor(bg)

    if chart_type == "scatter":
        # Only the symbol of the marker tuple is used (see docstring)
        symbol = kw.get('marker', ('o', 'skyblue', 0))[0]
        symbol = MARKER_TRANSLATION.get(symbol, symbol)

    xstart = 0.20
    # Leave extra room on the right for the legend when there is one
    if legend:
        xextent = 1.0 - (xstart + .20)
    else:
        xextent = 1.0 - (xstart + .10)
    plot = figure.add_axes([xstart, .2, xextent, .7])
    plot.set_facecolor(bg)
    plot.tick_params(labelsize=fontsize)

    # Remove the axis lines and ticks on the top and right
    plot.spines['right'].set_color('none')
    plot.spines['top'].set_color('none')
    plot.xaxis.set_ticks_position('bottom')
    plot.yaxis.set_ticks_position('left')

    # Add labels
    if title:
        plot.set_title(title, size=fontsize)
    if x_label:
        plot.set_xlabel(x_label, size=fontsize)
    if y_label:
        plot.set_ylabel(y_label, size=fontsize)

    # Axes ranges
    x_range = kw.get('x_range', (min(xvals), max(xvals)))
    y_range = kw.get(
        'y_range',
        (min([min(y) for y in ylists]), max([max(y) for y in ylists])))

    # Compute the top and bottom of every error bar so the default y range
    # can be widened to include them. The caller's err_y is left untouched.
    err_y = kw.get('err_y')
    if err_y:
        err_top = []
        err_bot = []
        # ylists and err_y are both lists of lists, each inner list is the
        # data for a series.
        for this_y, this_err_y in zip(ylists, err_y):
            for yval, yerr in zip(this_y, this_err_y):
                err_top.append(yval + yerr)
                err_bot.append(yval - yerr)
        if 'y_range' not in kw and err_top:
            err_min = min(min(err_bot), min(err_top))
            ymin = min(y_range[0], err_min)
            err_max = max(max(err_bot), max(err_top))
            ymax = max(y_range[1], err_max)
            y_range = (ymin, ymax)
    plot.set_xlim(x_range)
    plot.set_ylim(y_range)

    # Plot each series
    for index, yvals in enumerate(ylists):
        color = colors[index % len(colors)]
        try:
            label = legend[index]
        except IndexError:
            label = 'none'
        if chart_type == 'scatter':
            plot.scatter(xvals,
                         yvals,
                         c=color,
                         marker=symbol,
                         edgecolors='none',
                         s=9,
                         label=label)
        elif err_y:
            plot.plot(xvals, yvals, c=color, label=label)
            # Only draw error bars for points with a nonzero error. A series
            # with no such points gets no error bars at all.
            err_points = [
                point for point in zip(xvals, yvals, err_y[index]) if point[2]
            ]
            if err_points:
                xvals_err, yvals_err, yerr_err = zip(*err_points)
                plot.errorbar(xvals_err,
                              yvals_err,
                              yerr=yerr_err,
                              c=color,
                              ls='none',
                              label=label)
        else:
            plot.plot(xvals, yvals, c=color, label=label)

    # Put on a legend, anchored just to the right of the plot area
    if legend:
        anchor = (1.0 / (xstart + xextent) + .07, 0.50)
        if chart_type == 'scatter':
            plot.legend(prop={'size': fontsize},
                        frameon=False,
                        loc='center right',
                        scatterpoints=1,
                        handlelength=0.1,
                        handletextpad=0.5,
                        bbox_to_anchor=anchor)
        else:
            plot.legend(prop={'size': fontsize},
                        frameon=False,
                        loc='center right',
                        numpoints=1,
                        handletextpad=0.3,
                        handlelength=1.0,
                        bbox_to_anchor=anchor)
def get_2var_plot(data, **kw):
    """
    Create a 2d contour or 3d surface plot.

    The plot is saved as an image file. The filename keyword gives the path
    of the file to write; that file is written and the value of filename is
    returned.

    :type data: list of tuples
    :param data: the (x, y, z) values to plot. List items should be arranged
        so they are sorted by X and then by Y (so that the Y varies faster than
        X), and there is a Z value for every X/Y combination.

    :type chart_type: str
    :keyword chart_type: type of chart to produce
        - contour: 2d contour
        - surface: 3d surface
        - wireframe: 3d wireframe surface

    :type size: tuple
    :keyword size: (x, y) plot size, in pixels at the given dpi

    :type x_label: str
    :keyword x_label: X-axis label

    :type y_label: str
    :keyword y_label: Y-axis label

    :type z_label: str
    :keyword z_label: Z-axis label (surface and wireframe only)

    :type x_range: tuple
    :keyword x_range: (min, max) values for the X-axis. (contour plots only)

    :type y_range: tuple
    :keyword y_range: (min, max) values for the Y-axis (contour plots only)

    :type x_reverse: bool
    :keyword x_reverse: True if the X axis should be reversed, False (default)
        if not

    :type y_reverse: bool
    :keyword y_reverse: True if the Y axis should be reversed, False (default)
        if not

    :type z_reverse: bool
    :keyword z_reverse: True if the Z axis should be reversed, False (default)
        if not (surface and wireframe only)

    :type color_map: str
    :param color_map: Name of a matplotlib color map

    :type color_range: tuple
    :keyword color_range: (min, max) of the color range. Values of min and
        below will get the minimum color, values of max and above will get the
        maximum color. Setting all the contour levels with the levels keyword
        may be the preferred way of accomplishing this.

    :type bg: str
    :keyword bg: color name for the plot background.

    :type legend: bool
    :keyword legend: True if a colorbar legend that shows the contour levels
        should be included, False if not (False is default)

    :type legend_format: str
    :keyword legend_format: String format specifier for the colorbar legend.
        This format is also used for contour labels.

    :type legend_orientation: str
    :keyword legend_orientation: Either 'vertical' (default) or 'horizontal'

    :type title: str
    :keyword title: the title of the plot

    :type dpi: int
    :keyword dpi: dots per inch for the plot

    :type fontsize: int or str
    :keyword fontsize: size in points, or one of the following -
        xx-small, x-small, small, medium, large, x-large, xx-large

    :type filename: str
    :keyword filename: The pathway to a file that the image of this plot
        should be saved in.

    :type format: str
    :keyword format: the image format to save the chart in. Must be a
        matplotlib-recognized format argument to Figure.savefig(format=nnn).
        Default is nnn='png'

    :type fill: bool
    :keyword fill: True if the contours should be filled, False (default) if
        lines on a white background (contour plots only)

    :type labels: bool
    :keyword labels: True (default) if the contour lines should be labeled,
        False if not (contour plots only)

    :type contours: int
    :keyword contours: The number of contour lines to draw (default of 0
        indicates that matplotlib should choose the optimal number) (contour
        plots only)

    :type levels: list
    :keyword levels: list of values to place contour lines at (contour plots
        only)

    :type viewpoint: tuple of 2 float
    :param viewpoint: (a, b) describing the point of view from which the plot
        is viewed, where a is the elevation and b is the rotation of the plot.
        Default is (45, 45). (surface and wireframe only)

    :rtype: filename
    :return: The filename the image of the plot was saved into (same string
        as passed in with the filename keyword).
    """
    from matplotlib.backends.backend_agg import FigureCanvasAgg
    from matplotlib.figure import Figure

    # Grab some of the keywords, the rest will be passed on
    dpi = int(kw.get('dpi', 100))
    plotsize = kw.get('size', (300, 200))
    # True division: the figure size is in (possibly fractional) inches.
    # Floor division would shrink the image, or give a zero size when the
    # requested size is smaller than dpi.
    width = plotsize[0] / dpi
    height = plotsize[1] / dpi

    # Create the basic plot area
    figure = Figure(figsize=(width, height), dpi=dpi)
    canvas = FigureCanvasAgg(figure)
    chart_type = kw.get('chart_type', 'contour')
    if chart_type in ['surface', 'wireframe']:
        create_surface_on_figure(figure, data, **kw)
    else:
        create_contour_plot_on_figure(figure, data, **kw)

    # Make sure the tick labels don't overlap
    prevent_overlapping_x_labels(canvas)

    # Save the plot to an image file
    filename = kw.get('filename')
    image_format = kw.get('format', 'png')
    figure.savefig(filename,
                   dpi=dpi,
                   orientation='landscape',
                   format=image_format)
    return filename
def _get_legend_format(values): """ Determine the format for legend values. The format is based on the type and span of the items in values: - If type(values[0]) is int, format will be integer - If span `(abs(max(values)-min(values))) < 1` format is 3 decimal places - If span < 10 format is 2 decimal places - If span < 100 format is 1 decimal place - Else format is integer :type values: iterable :param values: The list of values to check for type and span :rtype: str :return: A Python format string (such as '%.1f') for the legend values. """ if not values or isinstance(values[0], int): return '%d' span = abs(max(values) - min(values)) if span < 1.0: return '%.3f' elif span < 10.0: return '%.2f' elif span < 100.0: return '%.1f' else: return '%.0f'
def create_contour_plot_on_figure(figure, data, **kw):
    """
    Create a 2d contour plot on an existing matplotlib Figure.

    :type figure: matplotlib Figure object
    :param figure: the Figure object the plot should be created on.

    :type data: list of tuples
    :param data: the (x, y, z) values to plot. List items should be arranged
        so they are sorted by X and then by Y (so that the Y varies faster than
        X), and there is a Z value for every X/Y combination.

    :type x_label: str
    :keyword x_label: X-axis label

    :type y_label: str
    :keyword y_label: Y-axis label

    :type x_range: tuple
    :keyword x_range: (min, max) values for the X-axis

    :type y_range: tuple
    :keyword y_range: (min, max) values for the Y-axis

    :type x_reverse: bool
    :keyword x_reverse: True if the X axis should be reversed, False (default)
        if not

    :type y_reverse: bool
    :keyword y_reverse: True if the Y axis should be reversed, False (default)
        if not

    :type color_map: str
    :param color_map: Name of a matplotlib color map. When fill is True, it
        colors the filled regions and the contour lines are drawn in black.

    :type color_range: tuple
    :keyword color_range: (min, max) of the color range. Values of min and
        below will get the minimum color, values of max and above will get the
        maximum color. Setting all the contour levels with the levels keyword
        may be the preferred way of accomplishing this.

    :type bg: str
    :keyword bg: color name for the plot background.

    :type legend: bool
    :keyword legend: True if a colorbar legend that shows the contour levels
        should be included, False if not (False is default)

    :type legend_format: str
    :keyword legend_format: String format specifier for the colorbar legend.
        This format is also used for contour labels.

    :type legend_orientation: str
    :keyword legend_orientation: Either 'vertical' (default) or 'horizontal'

    :type title: str
    :keyword title: the title of the plot

    :type fontsize: int or str
    :keyword fontsize: size in points, or one of the following -
        xx-small, x-small, small, medium, large, x-large, xx-large

    :type fill: bool
    :keyword fill: True if the contours should be filled, False (default) if
        lines on a white background

    :type labels: bool
    :keyword labels: True (default) if the contour lines should be labeled,
        False if not

    :type contours: int
    :keyword contours: The number of contour lines to draw (default of 0
        indicates that matplotlib should choose the optimal number)

    :type levels: list
    :keyword levels: list of values to place contour lines at

    :raise ValueError: if data does not hold a Z value for every X/Y
        combination
    :raise IndexError: if an item of data has fewer than 3 values
    """
    title = kw.get('title')
    bg = kw.get('bg', 'white')
    legend = kw.get('legend', False)
    legend_format = kw.get('legend_format', None)
    legend_orientation = kw.get('legend_orientation', 'vertical')
    x_label = kw.get('x_label')
    y_label = kw.get('y_label')
    color_map = kw.get('color_map', None)
    color_range = kw.get('color_range', None)
    fontsize = kw.get('fontsize', DEFAULT_FONTSIZE)
    # Contour labels need a point size: named sizes are translated, numeric
    # sizes (documented as allowed) are used unchanged.
    label_fontsize = FONT_SCALINGS.get(fontsize, fontsize)
    fill = kw.get('fill', False)
    labels = kw.get('labels', True)
    contours = kw.get('contours', 0)
    levels = kw.get('levels', [])
    xreverse = kw.get('x_reverse', False)
    yreverse = kw.get('y_reverse', False)

    figure.set_facecolor(bg)
    plot = figure.add_axes([.13, .13, .8, .8])
    plot.set_facecolor(bg)
    plot.tick_params(labelsize=fontsize)
    plot.xaxis.set_ticks_position('bottom')
    plot.yaxis.set_ticks_position('left')

    # Add labels
    if title:
        plot.set_title(title, size=fontsize)
    if x_label:
        plot.set_xlabel(x_label, size=fontsize)
    if y_label:
        plot.set_ylabel(y_label, size=fontsize)

    # Unroll the data list. We need a list of X of length M, a list
    # of Y of length N, and a 2-D array of Z sized MxN
    try:
        xvals = []
        yvals = []
        ztemp = []
        firstx = True
        for point in data:
            pointx = point[0]
            if not xvals:
                xvals.append(pointx)
            elif pointx != xvals[-1]:
                xvals.append(pointx)
                # We've moved on to the next x-value, so have cycled through
                # all Y points
                firstx = False
            if firstx:
                yvals.append(point[1])
            ztemp.append(point[2])
    except IndexError:
        print('The data parameter should be tuples of (x, y, z) values')
        raise

    if not legend_format:
        legend_format = _get_legend_format(ztemp)

    # Create a 2-D array of zvalues
    numx = len(xvals)
    numy = len(yvals)
    if len(ztemp) != numx * numy:
        raise ValueError('The number of data points should be MxN, where M '
                         'is the number of\nunique X values and N is the '
                         'number of unique Y values')
    # ztemp is X-major (Y varies fastest); the contour functions want an
    # array indexed [y, x], so reshape to (numx, numy) and transpose.
    zvals = numpy.array(ztemp, dtype=numpy.double).reshape(numx, numy).T

    # Axes ranges
    x_range = kw.get('x_range', (min(xvals), max(xvals)))
    y_range = kw.get('y_range', (min(yvals), max(yvals)))
    if xreverse:
        x_range = (x_range[1], x_range[0])
    if yreverse:
        y_range = (y_range[1], y_range[0])

    # Plot the contours
    grid = (xvals, yvals, zvals)
    contour_colors = None
    # Filled contours
    if fill:
        if levels:
            contour_fill = plot.contourf(*grid, levels=levels, cmap=color_map)
        elif contours:
            contour_fill = plot.contourf(*grid, contours, cmap=color_map)
        else:
            contour_fill = plot.contourf(*grid, cmap=color_map)
        if color_range:
            contour_fill.set_clim(color_range)
        # Draw the boundary lines in black on top of the filled regions
        contour_colors = 'k'

    # matplotlib rejects colors and cmap given together, so the line
    # colormap is only used when the lines do not have a fixed color.
    line_cmap = color_map if contour_colors is None else None
    # Plot the contour boundaries
    if levels:
        contour = plot.contour(*grid,
                               levels=levels,
                               colors=contour_colors,
                               cmap=line_cmap)
    elif contours:
        contour = plot.contour(*grid,
                               contours,
                               colors=contour_colors,
                               cmap=line_cmap)
    else:
        contour = plot.contour(*grid, colors=contour_colors, cmap=line_cmap)
    if color_range:
        contour.set_clim(color_range)

    # Add boundary labels
    if labels:
        if fill:
            plot.clabel(contour,
                        colors='k',
                        fmt=legend_format,
                        fontsize=label_fontsize)
        else:
            plot.clabel(contour, fontsize=label_fontsize, fmt=legend_format)

    # Add the color bar
    if legend:
        if fill:
            colorbar = figure.colorbar(contour_fill,
                                       shrink=0.8,
                                       format=legend_format,
                                       orientation=legend_orientation)
        else:
            colorbar = figure.colorbar(contour,
                                       shrink=0.8,
                                       format=legend_format,
                                       orientation=legend_orientation)
            # Thicken the level lines in the colorbar. Depending on the
            # matplotlib version, colorbar.lines is either one line
            # collection or a list of them.
            cb_lines = colorbar.lines
            if not isinstance(cb_lines, (list, tuple)):
                cb_lines = [cb_lines]
            for line_collection in cb_lines:
                line_collection.set_linewidth(6)
        colorbar.ax.tick_params(labelsize=fontsize)

    plot.set_xlim(x_range)
    plot.set_ylim(y_range)
def create_surface_on_figure(figure, data, **kw):
    """
    Create a 3d surface or wireframe plot on an existing matplotlib Figure.

    :type figure: matplotlib Figure object
    :param figure: the Figure object the plot should be created on.

    :type data: list of tuples
    :param data: the (x, y, z) values to plot. List items should be arranged
        so they are sorted by X and then by Y (so that the Y varies faster than
        X), and there is a Z value for every X/Y combination.

    :type chart_type: str
    :keyword chart_type: 'wireframe' for a wireframe surface, anything else
        (default 'surface') for a solid surface

    :type x_label: str
    :keyword x_label: X-axis label

    :type y_label: str
    :keyword y_label: Y-axis label

    :type z_label: str
    :keyword z_label: Z-axis label

    :type x_reverse: bool
    :keyword x_reverse: accepted for compatibility; currently has no effect
        because axis limits are not applied to 3D plots

    :type y_reverse: bool
    :keyword y_reverse: accepted for compatibility; currently has no effect

    :type z_reverse: bool
    :keyword z_reverse: accepted for compatibility; currently has no effect

    :type bg: str
    :keyword bg: color name for the plot background.

    :type legend: bool
    :keyword legend: True if a colorbar legend that shows the surface levels
        should be included, False if not (False is default). Surface only.

    :type legend_format: str
    :keyword legend_format: String format specifier for the colorbar legend.

    :type legend_orientation: str
    :keyword legend_orientation: Either 'vertical' (default) or 'horizontal'

    :type title: str
    :keyword title: the title of the plot

    :type fontsize: int or str
    :keyword fontsize: size in points, or one of the following -
        xx-small, x-small, small, medium, large, x-large, xx-large

    :type viewpoint: tuple of 2 float
    :param viewpoint: (a, b) describing the point of view from which the plot
        is viewed, where a is the elevation and b is the rotation of the plot.
        Default is (45, 45).

    :raise ValueError: if x_range, y_range or z_range is given, if data is
        empty, or if data does not hold a Z value for every X/Y combination
    :raise IndexError: if an item of data has fewer than 3 values
    """
    title = kw.get('title')
    bg = kw.get('bg', 'white')
    legend = kw.get('legend', False)
    legend_format = kw.get('legend_format', None)
    legend_orientation = kw.get('legend_orientation', 'vertical')
    x_label = kw.get('x_label')
    y_label = kw.get('y_label')
    z_label = kw.get('z_label')
    fontsize = kw.get('fontsize', DEFAULT_FONTSIZE)
    # Note that changing xtick changes all axis font sizes. This modifies the
    # global matplotlib rc settings, so it also affects later plots.
    matplotlib.rc('xtick', labelsize=fontsize)
    elevation, azimuth = kw.get('viewpoint', (45, 45))
    chart_type = kw.get('chart_type', 'surface')

    figure.set_facecolor(bg)
    plot = figure.add_axes([.13, .13, .8, .8], projection='3d')
    plot.set_facecolor(bg)
    plot.tick_params(labelsize=fontsize)
    plot.xaxis.set_ticks_position('bottom')
    plot.yaxis.set_ticks_position('left')

    # Add labels
    if title:
        plot.set_title(title, size=fontsize)
    if x_label:
        plot.set_xlabel(x_label, size=fontsize)
    if y_label:
        plot.set_ylabel(y_label, size=fontsize)
    if z_label:
        plot.set_zlabel(z_label, size=fontsize)

    # Unroll the data list. We need a list of X of length M, a list
    # of Y of length N, and a 2-D array of Z sized MxN
    try:
        xtemp = []
        ytemp = []
        ztemp = []
        firstx = True
        for point in data:
            pointx = point[0]
            if not xtemp:
                xtemp.append(pointx)
            elif pointx != xtemp[-1]:
                xtemp.append(pointx)
                # We've moved on to the next x-value, so have cycled through
                # all Y points
                firstx = False
            if firstx:
                ytemp.append(point[1])
            ztemp.append(point[2])
    except IndexError:
        print('The data parameter should be tuples of (x, y, z) values')
        raise

    if not legend_format:
        legend_format = _get_legend_format(ztemp)

    # Axes ranges
    if 'x_range' in kw or 'y_range' in kw or 'z_range' in kw:
        # Note that setting the Z-axis range doesn't appear to work at all,
        # while setting the X & Y ranges makes the plot go haywire. For the
        # same reason the x/y/z_reverse keywords are not applied.
        raise ValueError('Modifying axis ranges is not allowed for surface or'
                         '\n wireframe plots due to matplotlib limitations')
    if not ztemp:
        raise ValueError('No (x, y, z) data points were given')

    # Create a 2-D array of zvalues
    numx = len(xtemp)
    numy = len(ytemp)
    if len(ztemp) != numx * numy:
        raise ValueError('The number of data points should be MxN, where M '
                         'is the number of\nunique X values and N is the '
                         'number of unique Y values')
    # ztemp is X-major (Y varies fastest); the surface functions want an
    # array indexed [y, x], matching the meshgrid below.
    zvals = numpy.array(ztemp, dtype=numpy.double).reshape(numx, numy).T
    xvals, yvals = numpy.meshgrid(xtemp, ytemp)

    # Plot the surface
    if chart_type == 'wireframe':
        surface = plot.plot_wireframe(xvals, yvals, zvals, cmap=mpl_cmap.jet)
    else:
        surface = plot.plot_surface(xvals, yvals, zvals, cmap=mpl_cmap.jet)
    if legend:
        colorbar = figure.colorbar(surface,
                                   shrink=0.8,
                                   format=legend_format,
                                   orientation=legend_orientation)
        colorbar.ax.tick_params(labelsize=fontsize)

    plot.view_init(elevation, azimuth)
if __name__ == "__main__":
    # Manual test driver: reads whitespace-separated x y z values (one point
    # per line, '#' starts a comment line) and writes a surface plot to
    # contour.png in the current directory.
    import sys

    if len(sys.argv) != 2:
        print("Usage: $SCHRODINGER/run %s <data-file>" % sys.argv[0])
        sys.exit(1)

    if not os.path.isfile(sys.argv[1]):
        print("Data file not found: %s" % sys.argv[1])
        sys.exit(1)

    with open(sys.argv[1], "r") as data_file:
        lines = data_file.read().split("\n")

    data = []
    for line in lines:
        line = line.strip()
        if line and not line.startswith("#"):
            linelist = []
            for atoken in line.split():
                # Keep integers as ints so the legend format is chosen
                # appropriately
                try:
                    linelist.append(int(atoken))
                except ValueError:
                    linelist.append(float(atoken))
            data.append(tuple(linelist))

    plot = get_2var_plot(data,
                         fill=True,
                         legend=True,
                         labels=False,
                         legend_orientation='horizontal',
                         x_label='bob',
                         y_label='joe',
                         title='jim',
                         viewpoint=(45, 45),
                         size=(500, 500),
                         z_label='jeff',
                         filename='contour.png',
                         chart_type='surface')