Source code for schrodinger.stepper.viz

"""
Visualization tools for generating plots based on stepper runs and workflows.

EXAMPLE CMDLINE USAGE

To create a timeline plot from a statistics.json created from a workflow
run:

    $SCHRODINGER/run python3 -m schrodinger.stepper.viz statistics.json \
    timeline.html --timeline


To create an outputs plot from a statistics.json created from a workflow
run:

    $SCHRODINGER/run python3 -m schrodinger.stepper.viz statistics.json \
    outputs_plot.html --outputs_plot --outputs_plot_depth 1

"""
import argparse
import json
from typing import Dict
from typing import List
import uuid

from schrodinger.stepper import stepper


def _step_to_timeline_row(step_id: str, stats: Dict) -> str:
    """
    Render one step of the stats dictionary as a Google timeline chart row.

    Rows are grouped by section, which is the first two dot-separated parts
    of the step id; the remaining parts form the row name. For example the
    step id:

        MyWorkflow.SectionA.ComputeStep

    has the section "MyWorkflow.SectionA" and the row name "ComputeStep".

    :param step_id: Dot-separated id of the step.
    :param stats: Statistics of the step; its 'start_time' and 'end_time'
        entries are read.
    :return: A JavaScript array literal followed by a comma and a newline.
    :raises KeyError: If 'start_time' or 'end_time' is missing from `stats`.
    """
    id_parts = step_id.split('.')
    section = '.'.join(id_parts[:2])
    row_name = '.'.join(id_parts[2:])
    # Start is looked up before end, so a missing 'start_time' is reported
    # first.
    start_date = _step_info_to_js_date(stats['start_time'])
    end_date = _step_info_to_js_date(stats['end_time'])
    return f"[ '{section}','{row_name}',{start_date},{end_date} ],\n"


def _step_info_to_js_date(time: str) -> str:
    """
    Convert a step timestamp into a JavaScript ``Date`` constructor call.

    '2021-10-01 18:43:01 GMT' -> 'new Date(2021, 9, 1, 18, 43, 1)'

    The month is shifted down by one because the multi-argument JavaScript
    ``Date`` constructor takes a zero-based month index (0 is January).
    Fields are emitted as plain integers so that zero-padded values such as
    ``08`` do not end up as legacy octal-style literals in the generated
    script.

    :param time: Timestamp of the form 'YYYY-MM-DD HH:MM:SS [TZ]'. Anything
        after the second space-separated token (e.g. 'GMT') is ignored.
    :return: JavaScript source text of the form ``new Date(...)``.
    :raises ValueError: If the timestamp does not have the expected shape.
    """
    date_part, clock_part = time.split(' ')[:2]
    year, month, day = (int(field) for field in date_part.split('-'))
    clock_fields = [int(field) for field in clock_part.split(':')]
    # JavaScript months run 0-11, the timestamp's run 1-12.
    js_fields = [year, month - 1, day, *clock_fields]
    return f"new Date({', '.join(str(field) for field in js_fields)})"


def _workflow_or_stats_fname_to_dict(workflow_or_stats_fname: str) -> Dict:
    """
    Produce a statistics dictionary from either a step or a stats file.

    :param workflow_or_stats_fname: Either a `stepper._BaseStep` instance,
        whose `getRunInfo()` result is used, or the path of a JSON file that
        is loaded and returned.
    :return: The statistics dictionary.
    :raises RuntimeError: If a step is given and its run info is empty.
    """
    if not isinstance(workflow_or_stats_fname, stepper._BaseStep):
        # Anything that is not a step is treated as a path to a JSON file.
        with open(workflow_or_stats_fname) as stats_file:
            return json.load(stats_file)

    run_info = workflow_or_stats_fname.getRunInfo()
    if len(run_info) == 0:
        raise RuntimeError("Step has not been run.")
    return run_info


def _filter_steps_for_depth(list_of_step_ids: List, depth: int) -> List:
    """
    Pick the step ids that should be plotted for the requested depth.

    The depth of a step id is the number of '.' separators it contains.

    Example I/Os:
        (list_of_step_ids, depth) => filtered list of step ids to plot
        1. (['a', 'a.b', 'a.c'], 0) => ['a']
        2. (['a', 'a.b', 'a.c'], 1) => ['a.b', 'a.c']
        3. (['a', 'a.b', 'a.b.d', 'a.b.e', 'a.c'], 2) => ['a.b.d', 'a.b.e', 'a.c']
        4. (['a', 'a.b', 'a.b.d', 'a.b.e', 'a.c'], 1) => ['a.b', 'a.c']

    Note: A step that is shallower than the requested depth and has no
    children is plotted in place of the missing deeper steps. (In example 3
    above, 'a.c' has no children at depth 2, so 'a.c' itself is plotted.)

    :param list_of_step_ids: Step ids to filter.
    :param depth: Depth level to plot.
    :return: The step ids to plot, in their original order.
    """
    selected_ids = []
    last_position = len(list_of_step_ids) - 1

    for position, candidate_id in enumerate(list_of_step_ids):
        candidate_depth = candidate_id.count('.')

        # Steps nested deeper than requested are never plotted.
        if candidate_depth > depth:
            continue

        # A shallower step is dropped only when it has children; the last
        # entry of the list is kept without running the children check.
        if (candidate_depth < depth and position != last_position and
                _step_has_children(list_of_step_ids, candidate_id)):
            continue

        selected_ids.append(candidate_id)

    return selected_ids


def _step_has_children(list_of_step_ids: list, step_id: str) -> bool:
    """
    Check if a workflow step has any descendants in the list of step ids.

    A descendant is any id that consists of `step_id`, a '.' separator and
    further text. The separator is part of the check so that a sibling whose
    name merely begins with the same characters (e.g. 'W.Filter_10' next to
    'W.Filter_1') is not mistaken for a child.

    :param list_of_step_ids: All step ids to search through.
    :param step_id: Id of the potential parent step.
    :return: True if at least one descendant of `step_id` is in the list.
    """
    # Including the separator also keeps step_id from matching itself.
    child_prefix = step_id + '.'
    return any(
        other_id.startswith(child_prefix) for other_id in list_of_step_ids)


def _step_to_outputs_plot_row(step_id: str, stats: Dict) -> str:
    """
    Render one step of the stats dictionary as a Google chart data row.

    Output format:
    "[ STEP_NAME, OUTPUTS ]" e.g. "[ 'Workflow.WordFilter_0', 121 ]"

    :param step_id: Id of the step, used as the row label.
    :param stats: Statistics of the step; its 'num_outputs' entry is read.
    :return: A JavaScript array literal followed by a comma and a newline.
    :raises KeyError: If 'num_outputs' is missing from `stats`.
    """
    output_count = stats['num_outputs']
    return "[ '" + step_id + "'," + f"{output_count}" + " ],\n"


class VizWebpage:
    """
    Collects stepper workflow charts and renders them into one HTML page.

    Example usage::

        workflow = MyStepperWorkflow()
        workflow.setInputs(range(10))
        op = workflow.getOutputs()
        with VizWebpage('my_summary_charts.html') as page:
            page.addChart(TimelinePlot(workflow))
            page.addChart(OutputsPlot(workflow, depth=1))

    When used as a context manager the page is written on leaving the
    ``with`` block; `write` can also be called directly.
    """

    def __init__(self, fname):
        """
        :param fname: Path of the HTML file that `write` opens for writing.
        """
        self._charts = []
        self._fname = fname

    def addChart(self, chart: 'AbstractPlot'):
        """
        Register a chart. Charts appear on the page in the order added.

        :param chart: Object providing `getJSScript` and `getHtml`.
        """
        self._charts.append(chart)

    def __enter__(self, *args):
        return self

    def __exit__(self, *args):
        # The page is written on every exit, including when the block raised.
        self.write()

    def write(self):
        """
        Write the page (head with chart scripts, body with chart markup) to
        the file name given at construction.
        """
        with open(self._fname, 'w') as viz_file:
            viz_file.write("<html>\n")
            viz_file.write(self._getPageHeader())
            viz_file.write(self._getPageBody())
            viz_file.write("</html>\n")

    def _getPageHeader(self):
        """
        Build the <head> element: the Google charts loader followed by the
        script of every registered chart.
        """
        head_open = """
    <head>
        <script type="text/javascript" src="https://www.gstatic.com/charts/loader.js"></script>
        """
        head_close = """
    </head>"""
        chart_scripts = ''.join(
            chart.getJSScript() for chart in self._charts)
        return head_open + chart_scripts + head_close

    def _getPageBody(self):
        """
        Build the <body> element holding the markup of every registered chart.
        """
        body_open = """
    <body>
        <div >"""
        body_close = """
        </div>
    </body>"""
        chart_markup = ''.join(chart.getHtml() for chart in self._charts)
        return body_open + chart_markup + body_close
class AbstractPlot:
    """
    Base class for stepper charts.

    Subclasses must implement:
        - getJSScript
        - getHtml
    """

    def __init__(self):
        # First ten characters of a random UUID string; subclasses use it as
        # the id of the div their chart is drawn into.
        random_uuid = uuid.uuid4()
        self._unique_div_id = str(random_uuid)[:10]

    def getJSScript(self) -> str:
        """
        Must be implemented by subclasses. Returns the javascript that
        generates the plot. This will usually look something like::

            "<script><!--do javascript stuff that generates plot---></script>"

        This script is added to the header of the visualization webpage.
        """
        raise NotImplementedError

    def getHtml(self) -> str:
        """
        Must be implemented by subclasses. Returns the html the plot is
        rendered into. This is usually just a div with an id that the
        javascript references, e.g.

            "<div id='my_plot_1'></div>"

        This div is added to the body of the visualization webpage.
        """
        raise NotImplementedError
class _AbstractStatsPlot(AbstractPlot):
    """
    Similar to AbstractPlot but exposes a _stats_dict property that is a dict
    representing run statistics.
    """

    def __init__(self, workflow_or_stats_fname):
        """
        Store the source of the run statistics that the plot is built from.

        The source is not read here; it is resolved each time `_stats_dict`
        is accessed.

        :param workflow_or_stats_fname: A chain or step that has been run or a
            json file written from a step's `run_info` (e.g.
            `json.dump(my_chain.getRunInfo(), open('statistics.json', 'w'))`)
        """
        # AbstractPlot.__init__ assigns self._unique_div_id, so it is not
        # generated a second time here.
        super().__init__()
        self._workflow_or_stats_fname = workflow_or_stats_fname

    @property
    def _stats_dict(self):
        """
        Run statistics, resolved from the stored step or stats file on every
        access.

        :raises RuntimeError: If the given step has not yet been run.
        """
        return _workflow_or_stats_fname_to_dict(self._workflow_or_stats_fname)
class TopologyChart(AbstractPlot):
    """
    Plots the hierarchy of a workflow (i.e. a tree where parents are chains
    and children are substeps)
    """

    def __init__(self, workflow: stepper._BaseStep):
        """
        :param workflow: The step or chain whose hierarchy is plotted.
        """
        super().__init__()
        self._workflow = workflow

    def _getJSHead(self) -> str:
        """
        Return the opening of the chart script, up to and including the start
        of the `addRows` array.
        """
        # The `//` comment must stay on its own line; anything following it
        # on the same line would be commented out in the browser.
        return """
    <script type="text/javascript">
        google.charts.load('current', {packages:["orgchart"]});
        google.charts.setOnLoadCallback(drawChart);

        function drawChart() {
            var data = new google.visualization.DataTable();
            data.addColumn('string', 'StepID');
            data.addColumn('string', 'ParentStepID');
            data.addColumn('string', 'ToolTip');

            // For each orgchart box, provide the name, manager, and tooltip to show.
            data.addRows([
            """

    def _workflowAsJSRows(self, step) -> list:  #[str]
        """
        Return one chart row for `step` followed, if `step` is a
        `stepper.UnbatchedChain`, by the rows of all of its substeps
        (recursively).

        Each row holds the full step id (shown as its last dot-separated
        part), the parent step id (empty for a top-level step) and the
        step's docstring as tooltip.
        """
        step_id = step.getStepId()
        # Without a '.' rpartition yields an empty parent and the whole id as
        # the node name.
        parent, _, exact_node_id = step_id.rpartition('.')

        # The tooltip is emitted as a JavaScript template literal: backticks
        # are dropped, and backslashes and '${' are escaped so the docstring
        # text is shown literally instead of being interpreted.
        tooltip = step.__doc__ or ''
        tooltip = tooltip.replace('\\', '\\\\')
        tooltip = tooltip.replace('`', '')
        tooltip = tooltip.replace('${', '\\${')

        rows = [
            f"[{{'v':'{step_id}', 'f':'{exact_node_id}'}}, '{parent}', `{tooltip}`],"
        ]
        if isinstance(step, stepper.UnbatchedChain):
            for substep in step:
                rows.extend(self._workflowAsJSRows(substep))
        return rows

    def getJSScript(self) -> str:
        """
        Return the complete script that draws the org chart of the workflow.
        """
        steps_as_rows = self._workflowAsJSRows(self._workflow)
        return (self._getJSHead() + '\n'.join(steps_as_rows) +
                self._getJSFoot())

    def _getJSFoot(self) -> str:
        """
        Return the end of the chart script: it closes the row array and draws
        the chart into this plot's div.
        """
        # An HTML end tag cannot carry attributes, hence a plain </script>.
        return f"""
            ]);

            // Create the chart.
            var chart = new google.visualization.OrgChart(document.getElementById('{self._unique_div_id}'));
            // Draw the chart, setting the allowHtml option to true for the tooltips.
            chart.draw(data, {{'allowCollapse':true, 'allowHtml':true}});
        }}
    </script>
    """

    def getHtml(self):
        """
        Return the div that the org chart is drawn into.
        """
        return f'<div id="{self._unique_div_id}" style="width: 100%; height: 50%;"></div>'
class TimelinePlot(_AbstractStatsPlot):
    """
    Stepper timeline charts visualize what steps are running and for how
    long. These charts are useful for profiling and determining where
    bottlenecks are in the workflow.

    See https://developers.google.com/chart/interactive/docs/gallery/timeline
    for an example timeline plot.
    """

    def _getJSHead(self) -> str:
        """
        Return the opening of the chart script, up to and including the start
        of the `addRows` array.
        """
        template = """
    <script type="text/javascript">
        google.charts.load('current', {'packages':['timeline']});
        google.charts.setOnLoadCallback(drawChart);
        function drawChart() {
            var container = document.getElementById('%s');
            var chart = new google.visualization.Timeline(container);
            var dataTable = new google.visualization.DataTable();

            dataTable.addColumn({ type: 'string', id: 'Section' });
            dataTable.addColumn({ type: 'string', id: 'Name' });
            dataTable.addColumn({ type: 'date', id: 'Start' });
            dataTable.addColumn({ type: 'date', id: 'End' });
            dataTable.addRows([
            """
        return template % self._unique_div_id

    def getJSScript(self) -> str:
        """
        Return the complete timeline script, with one row per step of the
        run statistics. Steps lacking timing data are skipped with a printed
        warning.
        """
        script_parts = [self._getJSHead()]
        for step_id, stats in self._stats_dict.items():
            try:
                script_parts.append(_step_to_timeline_row(step_id, stats))
            except KeyError:
                print(f"WARNING: {step_id} has missing data")
        script_parts.append(self._getJSFoot())
        return ''.join(script_parts)

    def _getJSFoot(self) -> str:
        """
        Return the end of the chart script: it closes the row array and draws
        the chart.
        """
        return """
            ]);

            chart.draw(dataTable);
        }
    </script>
    """

    def getHtml(self) -> str:
        """
        Return the div the timeline is drawn into, followed by a centered
        "Time" label.
        """
        div_id = self._unique_div_id
        return f'''
        <div id="{div_id}" style="height: 50%;"></div>
        <div style="text-align: center">Time</div>
        '''
class OutputsPlot(_AbstractStatsPlot):
    """
    An area chart visualization of the outputs per step in the workflow.

    See https://developers.google.com/chart/interactive/docs/gallery/areachart
    for an example of what these look like.
    """

    def __init__(self, fname, depth):
        """
        :param fname: Source of the run statistics, passed on to the base
            class.
        :param depth: Depth of the workflow steps to plot.
        """
        super().__init__(fname)
        self._depth = depth

    def _getJSHead(self):
        """
        Return the opening of the chart script, up to and including the
        column header row of the data table.
        """
        return """
    <script type="text/javascript">
        google.charts.load('current', {'packages':['corechart']});
        google.charts.setOnLoadCallback(drawChart);

        function drawChart() {
            var data = google.visualization.arrayToDataTable([
                ['Step', 'Output'],
                """

    def getJSScript(self):
        """
        Return the complete area chart script, with one row for every step
        selected for the configured depth. Steps lacking output counts are
        skipped with a printed warning.
        """
        script_parts = [self._getJSHead()]
        stats_dict = self._stats_dict
        ids_to_plot = _filter_steps_for_depth(list(stats_dict), self._depth)
        for step_id in ids_to_plot:
            step_stats = stats_dict[step_id]
            try:
                script_parts.append(
                    _step_to_outputs_plot_row(step_id, step_stats))
            except KeyError:
                print(f"WARNING: {step_id} has missing data")
        script_parts.append(self._getJSFoot())
        return ''.join(script_parts)

    def _getJSFoot(self):
        """
        Return the end of the chart script: it closes the data table, sets
        the chart options and draws the chart into this plot's div.
        """
        # '%%' is a literal percent sign; the template is %-formatted below.
        template = """
            ]);

            var options = {
                title: 'Outputs per step',
                chartArea: {
                    top: 55,
                    height: '25%%'
                },
                hAxis: {
                    title: 'Step',
                    slantedText: true,
                    slantedTextAngle: 310,
                    textStyle: {fontSize: 12},
                    titleTextStyle: {color: '#0000FF'}
                },
                vAxis: {minValue: 0, logScale: true}
            };

            var chart = new google.visualization.AreaChart(document.getElementById('%s'));
            chart.draw(data, options);
        }
    </script>
    """
        return template % self._unique_div_id

    def getHtml(self):
        """
        Return the div that the area chart is drawn into.
        """
        div_id = self._unique_div_id
        return f'<div id="{div_id}" style="width: 100%; height: 50%;"></div>'
def main():
    """
    Command line entry point: build the requested charts from a statistics
    file and write them to an HTML page.

    :raises RuntimeError: If --outputs_plot is given without
        --outputs_plot_depth.
    """
    parser = argparse.ArgumentParser(
        description='Visualization tools for generating plots based on '
        'stepper runs and workflows.')
    parser.add_argument('input_file_path',
                        type=str,
                        help='File path to statistics.json')
    parser.add_argument('output_file_name',
                        type=str,
                        help='Desired name for output HTML file')
    parser.add_argument(
        '--timeline',
        action='store_true',
        default=False,
        help='Whether or not to create a timeline visualization')
    parser.add_argument(
        '--outputs_plot',
        action='store_true',
        default=False,
        help='Whether or not to create an outputs plot visualization, '
        'requires --outputs_plot_depth')
    parser.add_argument(
        '--outputs_plot_depth',
        type=int,
        help='Depth of workflow to plot in outputs plot, e.g. '
        '\'--outputs_plot_depth 2\'',
        metavar='')
    args = parser.parse_args()

    # Checked before the page is opened so that no partial page is written
    # for an invalid command line. The comparison is against None because a
    # depth of 0 is a valid request.
    if args.outputs_plot and args.outputs_plot_depth is None:
        raise RuntimeError(
            'Expected --outputs_plot_depth to be '
            'specified when plotting outputs plot; e.g. '
            '\'input_file_path output_file_name --outputs_plot '
            '--outputs_plot_depth 2\'')

    with VizWebpage(args.output_file_name) as output_page:
        if args.timeline:
            output_page.addChart(TimelinePlot(args.input_file_path))
        if args.outputs_plot:
            output_page.addChart(
                OutputsPlot(args.input_file_path, args.outputs_plot_depth))

    # Reported only after the page has actually been written.
    if args.timeline:
        print(f"Timeline written to {args.output_file_name}")
    if args.outputs_plot:
        print(f"Outputs plot written to {args.output_file_name}")


if __name__ == '__main__':
    main()