# Source code for schrodinger.models.presets

import glob
import os
from enum import Enum
from typing import List

from schrodinger.models import json
from schrodinger.models.parameters import CompoundParam
from schrodinger.tasks import jobtasks
from schrodinger.ui.qt.appframework2 import jobnames
from schrodinger.utils import fileutils

_BAD_PRESET_CHARACTERS = set('*."/\\[]:;|,')
_ORDERING_FNAME = "ordering.txt"


class Direction(Enum):
    """
    Direction in which a preset can be moved within the preset ordering.

    Each member's value is the offset added to the preset's current index:
    ``UP`` moves it one slot toward the front of the list, ``DOWN`` one slot
    toward the end.
    """

    UP = -1
    DOWN = 1
class PanelPreset(CompoundParam):
    """
    The object written to (and read back from) each preset JSON file.
    """
    # The panel's model serialized to a JSON string; written by
    # `PresetManager.savePresetToFile` and decoded back into the model class
    # by `PresetManager._loadPresetFromData`.
    serialized_panel: str
class PresetManager:
    """
    Save, load, order, import and export presets of a panel's model.

    Each preset is stored as a JSON file in a per-panel directory under the
    user data directory. The preset marked as the default is stored with a
    ``.default.json`` suffix instead of ``.json`` and is kept at the front of
    the ordering. The display order of all presets is kept in a plain-text
    ordering file with one preset name per line.

    Subclasses may set the `model_class` class attribute instead of passing
    it to the constructor, and may override `defineInputParams` and
    `defineIgnoredParams` to keep some params out of saved presets.
    """
    model_class = None

    def __init__(self, panel_name, model_class=None):
        """
        :param panel_name: name of the panel; used as the name of the
            directory the presets are stored in.
        :param model_class: class of the model presets are saved for.
            Required unless the `model_class` class attribute is already set.
        :raises ValueError: if no model class is available.
        """
        self._panel_name = panel_name
        if model_class:
            self.model_class = model_class
        if not self.model_class:
            raise ValueError("Expected argument: `model_class`")

    def _makePresetsDirectory(self):
        """
        Create the presets directory and an empty ordering file if they do
        not exist yet.
        """
        os.makedirs(self._getPresetsDirectory(), exist_ok=True)
        ordering_fname = self._generateOrderingFileName()
        # Only create the ordering file when missing; truncating an existing
        # one would discard the user's ordering.
        if not os.path.exists(ordering_fname):
            with open(ordering_fname, 'w'):
                pass

    def defineInputParams(self):
        """
        Return the params that are reset before a preset is saved and that
        keep the model's current values when a preset is loaded.

        Override in subclasses; by default there are none.
        """
        return []

    def defineIgnoredParams(self):
        """
        Return additional params that are reset before a preset is saved and
        that keep the model's current values when a preset is loaded.

        Override in subclasses; by default there are none.
        """
        return []

    def _validatePresetName(self, name):
        """
        Raise ValueError if `name` cannot be used as a preset name.

        Preset names are used to build file paths in the presets directory,
        so blank names and names containing any character of
        `_BAD_PRESET_CHARACTERS` are rejected.
        """
        if not name.strip():
            raise ValueError("Preset name cannot be empty")
        for char in name:
            if char in _BAD_PRESET_CHARACTERS:
                raise ValueError(f"Can't name preset with character: {char}")

    def savePreset(self, name: str, model: CompoundParam):
        """
        Save the current state of `model` as the preset `name`.

        An existing preset with the same name (including the default one) is
        overwritten in place. A new preset is appended to the ordering.

        :param name: preset name; leading/trailing whitespace is stripped.
        :param model: the model to save.
        :raises ValueError: if the name is empty or contains a character in
            `_BAD_PRESET_CHARACTERS`.
        """
        name = name.strip()
        # Validate before touching the filesystem so a bad name does not
        # leave a stray directory or preset file behind.
        self._validatePresetName(name)
        self._makePresetsDirectory()
        try:
            presets_fname = self._getExistingPresetFileName(name)
        except FileNotFoundError:
            presets_fname = self._generatePresetFileName(name)
        self.savePresetToFile(presets_fname, model)
        ordering = self._getPresetOrdering()
        if name not in ordering:
            ordering.append(name)
        self._updateOrdering(ordering)

    def savePresetToFile(self, fname: str, model: CompoundParam):
        """
        Write `model` as a preset to the file `fname`.

        A copy of the model is serialized with its input params and ignored
        params reset, so their values are not stored in the preset.
        """
        preset = PanelPreset()
        input_params = self.defineInputParams()
        ignored_params = self.defineIgnoredParams()
        model_copy = self.model_class(model)
        if input_params:
            model_copy.reset(*input_params)
        if ignored_params:
            model_copy.reset(*ignored_params)
        preset.serialized_panel = json.dumps(model_copy)
        with open(fname, 'w') as save_file:
            json.dump(preset, save_file)

    def exportPresets(self, export_fname: str):
        """
        Export all presets to an option file at export_fname.

        The file holds a JSON object that maps each preset name to the
        contents of its preset file, stored as a string.
        """
        preset_dict = {}
        # loads the preset jsons into a dictionary as strings to avoid extra
        # json loading
        for preset_name in self.getAvailablePresets():
            preset_fname = self._getExistingPresetFileName(preset_name)
            with open(preset_fname, 'r') as p_file:
                preset_dict[preset_name] = p_file.read()
        with open(export_fname, 'w') as save_file:
            json.dump(preset_dict, save_file)

    def importPresets(self, import_fname: str, model: CompoundParam):
        """
        Import presets from an option file at import_fname, as written by
        `exportPresets`.

        Imported presets whose names clash with existing presets are given
        a new name (via `jobnames.get_next_jobname`). All imported presets
        are appended to the ordering. Every preset in the file is validated
        before any of them is written, so an invalid file does not leave a
        partial import behind.

        :param model: model used as a template to check that each preset's
            data can be loaded; it is not modified.
        :raises JSONDecodeError: the opts file is not valid JSON or correct
            class data
        :raises ValueError: a preset name in the file is empty or contains a
            character that is not allowed in preset names
        """
        with open(import_fname, 'r') as load_file:
            # dictionary values are stored as strings (serialized json strings)
            preset_dict = json.load(load_file)

        # First pass: validate everything without writing anything.
        for name, data in preset_dict.items():
            # Names come from an external file and become file paths, so
            # they get the same check as names passed to `savePreset`.
            self._validatePresetName(name)
            # load the preset data into a model copy to ensure json class is
            # valid
            model_copy = self.model_class(model)
            preset = json.loads(data, DataClass=PanelPreset)
            self._loadPresetFromData(preset, model_copy)

        # Second pass: write the presets, renaming on clashes.
        self._makePresetsDirectory()
        ordering = self._getPresetOrdering()
        for name, data in preset_dict.items():
            updated_name = name
            if name in ordering:
                updated_name = jobnames.get_next_jobname(name,
                                                         name_list=ordering)
            preset_fname = self._generatePresetFileName(updated_name)
            with open(preset_fname, 'w') as save_file:
                save_file.write(data)
            ordering.append(updated_name)
        self._updateOrdering(ordering)

    def loadPreset(self, name: str, model: CompoundParam):
        """
        Load the preset `name` into `model`.

        :raises FileNotFoundError: if there is no preset with that name.
        """
        presets_fname = self._getExistingPresetFileName(name)
        self.loadPresetFromFile(presets_fname, model)

    def loadPresetFromFile(self, fname: str, model: CompoundParam):
        """
        Load the preset stored in the file `fname` into `model`.
        """
        with open(fname, 'r') as save_file:
            preset = json.load(save_file, DataClass=PanelPreset)
        self._loadPresetFromData(preset, model)

    def deletePreset(self, name):
        """
        Delete the preset `name` and remove it from the ordering.

        :raises ValueError: if there is no preset with that name.
        """
        try:
            presets_fname = self._getExistingPresetFileName(name)
        except FileNotFoundError:
            raise ValueError(f"No preset '{name}' found.") from None
        ordering = self._getPresetOrdering()
        # Tolerate a preset file that is missing from the ordering instead
        # of failing with a bare `list.index` error.
        if name in ordering:
            ordering.remove(name)
        os.remove(presets_fname)
        self._updateOrdering(ordering)

    def setDefaultPreset(self, name):
        """
        Make `name` the default preset and move it to the front of the
        ordering. Any previous default becomes a regular preset.

        :raises ValueError: if there is no preset with that name.
        """
        if name not in self.getAvailablePresets():
            raise ValueError(f"No preset '{name}' found.")
        # Remove default if there is one.
        self.clearDefaultPreset()
        preset_fname = self._generatePresetFileName(name)
        new_preset_fname = self._generateDefaultPresetFileName(name)
        os.rename(preset_fname, new_preset_fname)
        ordering = self._getPresetOrdering()
        ordering.remove(name)
        ordering.insert(0, name)
        self._updateOrdering(ordering)

    def loadDefaultPreset(self, model):
        """
        Load the default preset into `model`.

        :raises RuntimeError: if no default preset is set.
        """
        default_preset = self.getDefaultPreset()
        if default_preset is None:
            raise RuntimeError("No default set.")
        self.loadPreset(default_preset, model)

    def getAvailablePresets(self):
        """
        Return the names of all presets, in display order.
        """
        return self._getPresetOrdering()

    def getDefaultPreset(self):
        """
        Return the name of the default preset, or None if none is set.

        :raises RuntimeError: if more than one default preset file exists.
        """
        # Escape the directory so glob metacharacters in the path (e.g. "["
        # in a user name) are matched literally.
        default_glob_path = os.path.join(
            glob.escape(self._getPresetsDirectory()), "*.default.json")
        default_preset_fnames = glob.glob(default_glob_path)
        if len(default_preset_fnames) > 1:
            raise RuntimeError("More than one default somehow saved.")
        if not default_preset_fnames:
            return None
        return self._getPresetNameFromFilePath(default_preset_fnames[0])

    def clearDefaultPreset(self):
        """
        Turn the current default preset, if any, back into a regular preset.
        Its position in the ordering is unchanged.
        """
        current_default = self.getDefaultPreset()
        if current_default is not None:
            current_default_fname = self._generateDefaultPresetFileName(
                current_default)
            new_fname = self._generatePresetFileName(current_default)
            os.rename(current_default_fname, new_fname)

    def movePreset(self, name, direction):
        """
        Move the preset `name` one slot in the ordering.

        :param direction: a `Direction` member; its value is added to the
            preset's index.
        :raises ValueError: if the preset does not exist, is the default
            preset, would move past either end of the ordering, or would
            move in front of the default preset.
        """
        ordering = self._getPresetOrdering()
        if name not in ordering:
            raise ValueError(f"No preset '{name}' found.")
        default = self.getDefaultPreset()
        if default and default == name:
            raise ValueError("Cannot move default preset")
        old_index = ordering.index(name)
        new_index = old_index + direction.value
        # The default preset occupies index 0, so while one exists no other
        # preset may move into that slot.
        up_to_default = bool(default) and new_index < 1
        out_of_bound = not 0 <= new_index < len(ordering)
        if up_to_default or out_of_bound:
            raise ValueError("Invalid preset move operation")
        ordering.pop(old_index)
        ordering.insert(new_index, name)
        self._updateOrdering(ordering)

    def _loadPresetFromData(self, preset: PanelPreset, model: CompoundParam):
        """
        Loads the preset to the given model.

        Input params and ignored params keep the values they currently have
        on `model`; everything else is taken from the preset.
        """
        new_model_values = json.loads(preset.serialized_panel,
                                      DataClass=self.model_class)
        for inp_param in self.defineInputParams():
            inp_param.setParamValue(new_model_values,
                                    inp_param.getParamValue(model))
        for ign_param in self.defineIgnoredParams():
            ign_param.setParamValue(new_model_values,
                                    ign_param.getParamValue(model))
        model.setValue(new_model_values)

    def _getPresetOrdering(self):
        """
        Return the preset names listed in the ordering file, in order, or
        an empty list if the file does not exist. Blank lines are skipped.
        """
        ordering_fname = self._generateOrderingFileName()
        if not os.path.exists(ordering_fname):
            return []
        with open(ordering_fname, 'r') as order_file:
            # A blank line would otherwise become a "" preset name that
            # `_updateOrdering` rejects as out of sync.
            return [line.strip('\n') for line in order_file if line.strip()]

    def _generatePresetFileName(self, name):
        """Return the path of the (non-default) preset file for `name`."""
        return os.path.join(self._getPresetsDirectory(), f"{name}.json")

    def _generateDefaultPresetFileName(self, name):
        """Return the path `name` is stored at while it is the default."""
        return os.path.join(self._getPresetsDirectory(),
                            f"{name}.default.json")

    def _generateOrderingFileName(self):
        """Return the path of the ordering file."""
        return os.path.join(self._getPresetsDirectory(), _ORDERING_FNAME)

    def _getPresetsDirectory(self):
        """Return the directory this panel's presets are stored in."""
        return os.path.join(
            fileutils.get_directory_path(fileutils.USERDATA), 'panel_presets',
            self._panel_name)

    def _getExistingPresetFileName(self, name):
        """
        Return the path of the existing file for preset `name`, whether it
        is stored as a regular or as the default preset.

        :raises FileNotFoundError: if neither file exists.
        """
        preset_fname = self._generatePresetFileName(name)
        if os.path.exists(preset_fname):
            return preset_fname
        default_fname = self._generateDefaultPresetFileName(name)
        if os.path.exists(default_fname):
            return default_fname
        raise FileNotFoundError(f'No file found for preset "{name}"')

    def _updateOrdering(self, ordering: List[str]):
        """
        Update the ordering of the presets to match `ordering`. If any
        presets are left out, they are appended to the end in arbitrary
        order.

        :raises ValueError: if `ordering` lists more presets than exist on
            disk, or names a preset that does not exist.
        """
        # Escape the directory so glob metacharacters in the path are
        # matched literally.
        presets_glob_path = os.path.join(
            glob.escape(self._getPresetsDirectory()), "*.json")
        current_presets = [
            self._getPresetNameFromFilePath(fname)
            for fname in glob.glob(presets_glob_path)
        ]
        if len(current_presets) < len(ordering):
            raise ValueError(
                "More presets specified than currently saved presets")
        elif len(current_presets) > len(ordering):
            presets_in_ordering = set(ordering)
            ordering.extend(preset for preset in current_presets
                            if preset not in presets_in_ordering)
        existing_presets = set(current_presets)
        for name in ordering:
            if name not in existing_presets:
                raise ValueError(
                    f"Ordering out of sync: {name} does not exist")
        with open(self._generateOrderingFileName(), 'w') as order_file:
            order_file.writelines(f"{name}\n" for name in ordering)

    def _getPresetNameFromFilePath(self, filepath):
        """
        Return the preset name for a preset file path by dropping the
        directory and the ".default" / ".json" parts of the file name.
        """
        filepath = os.path.basename(filepath)
        filepath = filepath.replace('.default', '')
        filepath = filepath.replace('.json', '')
        return filepath
class TaskPanelPresetManager(PresetManager):
    """
    A `PresetManager` that keeps task names and job configs out of presets.

    :param panel_tasks: abstract params (on `model_class`) of the panel's
        tasks. For every task its ``name`` param is ignored. For tasks that
        `jobtasks.is_jobtask` reports as job tasks, the ``job_config`` param
        is ignored as well.
    """

    def __init__(self, panel_name, model_class, panel_tasks):
        super().__init__(panel_name, model_class)
        self._panel_tasks = panel_tasks

    def defineIgnoredParams(self):
        """
        Return each task's ``name`` param, followed by its ``job_config``
        param for job tasks, in the order the tasks were given.
        """

        def _task_params(task):
            # Params of a single task that must not be stored in a preset.
            yield task.name
            if jobtasks.is_jobtask(task):
                yield task.job_config

        return [
            param for task in self._panel_tasks
            for param in _task_params(task)
        ]