Source code for schrodinger.protein.predictors

"""
This module contains classes that wrap prime backends that predict sequence
structures. Many of the parameters and class constants are from a time when
documentation was sparse. In the future, it's possible we'll tweak these
numbers as needed.
"""
import enum
import os

from schrodinger.job.util import hunt
from schrodinger.models import jsonable
from schrodinger.models import parameters
from schrodinger.protein import alignment
from schrodinger.protein import annotation
from schrodinger.protein import residue
from schrodinger.protein import sequence
from schrodinger.protein.constants import SSA_MAP
from schrodinger.tasks import tasks
from schrodinger.utils import fileutils

# Shorthand for the protein sequence annotation types; used below as the keys
# of `PRED_ANNO_TO_PRED_FUNC`.
SEQ_ANNO_TYPES = annotation.ProteinSequenceAnnotations.ANNOTATION_TYPES
# Prime (psp) data directory. Predictor model parameter files are read from
# `<PSP_DATA_DIR>/predictors/<PREDICTOR_NAME>` by
# `AbstractPredictor.generateModelFile`.
# NOTE(review): `os.path.join` with a single argument just returns that
# argument as a path string, so the call looks vestigial. It does, however,
# raise TypeError at import time if `hunt` returns None — confirm what `hunt`
# returns when psp is not installed before simplifying this.
PSP_DATA_DIR = os.path.join(hunt('psp', 'data'))


class AbstractPredictor(tasks.SubprocessCmdTask):
    """
    Base class for all predictors.

    Derived classes are expected to implement class constants for:

    - EXE - A string that should match to the predictor's executable. Most of
      the time this is the same as PREDICTOR_NAME
    - PREDICTOR_NAME - A string with the name of the predictor. This is used
      to find the Prime data directory that holds the model parameters used by
      the predictor.
    - CLASS_NUM - A parameter specific to the predictor. Usually found by
      looking through the Prime predictors source code.
    - NU - Another model parameter.
    - NY - Another model parameter.

    In addition, derived classes should implement the following methods:

    generateInputFile - Should generate the required input file at the file
        described by `input_fname`
    prediction - Should read `self.getLogAsString()` and parse out the actual
        prediction from the backend
    makeCmd - This only needs to be implemented if the backend takes a command
        different from the form: `executable model_fname input_fname`
    """
    EXE = NotImplemented
    PREDICTOR_NAME = NotImplemented
    CLASS_NUM = NotImplemented
    NU = NotImplemented
    NY = NotImplemented

    class Input(parameters.CompoundParam):
        seq: sequence.ProteinSequence = None
        aln: alignment.ProteinAlignment = None

    input = Input()
    input_fname: str
    model_fname: str

    def generateInputFile(self):
        """
        Generate the input file for the predictor. Typically includes a header
        (see `_getInputHeader`), the file name of the blast alignment, and the
        sequence to predict properties for.

        The input file should be written with the name `self.input_fname`.
        """
        raise NotImplementedError

    def prediction(self):
        """
        Return the actual prediction. This can take various forms depending
        on the predictor.
        """
        raise NotImplementedError

    def generateAlignmentFile(self):
        """
        Write the alignment file to be used as an input for the predictor.

        The file is a temporary file in the current working directory and is
        removed during postprocessing. The first line is the number of
        sequences, followed by one sequence per line. Gaps are written as '.'.
        """
        self._aln_file = fileutils.tempfilename(suffix=None,
                                                temp_dir=os.getcwd())
        # The predictors are given the bare file name; it resolves because
        # the file was created in the current working directory.
        self._aln_fname = os.path.basename(self._aln_file)
        with open(self._aln_fname, 'w', newline="\n") as aln_file:
            aln_file.write(str(len(self.input.aln)) + '\n')
            for seq in self.input.aln:
                # Predictor alignment files use '.' for gap characters
                seq_str = str(seq).replace(seq.gap_char, '.')
                aln_file.write(seq_str + '\n')

    def _removeAlnFile(self):
        """
        Remove the temporary alignment file, at most once.

        This is a no-op if `generateAlignmentFile` never ran (e.g. an earlier
        preprocessor failed) or if the file was already cleaned up.
        """
        aln_file = getattr(self, '_aln_file', None)
        if aln_file is None:
            return
        # Clear first so that the second cleanup path (the postprocessor and
        # the `postprocess` override both call this) does nothing.
        self._aln_file = None
        if hasattr(aln_file, 'remove'):
            aln_file.remove()
        else:
            # NOTE(review): `fileutils.tempfilename` may return a plain path
            # (a `str` has no `remove` method) — remove it by path instead.
            try:
                os.remove(aln_file)
            except FileNotFoundError:
                pass

    @tasks.postprocessor
    def _cleanUpTmpFiles(self):
        self._removeAlnFile()

    def generateModelFile(self):
        """
        Generate the model definition file with the name `self.model_fname`.

        This is done by finding the Prime data directory for the predictor
        and getting the names of all the files in it. The model file includes
        a header with the number of model files and the predictor's class
        number (`self.CLASS_NUM`), followed by the sorted model file paths,
        one per line, each followed by ' \\n'.
        """
        model_dir = os.path.join(PSP_DATA_DIR, 'predictors',
                                 self.PREDICTOR_NAME)
        model_list = [
            os.path.join(model_dir, model_file)
            for model_file in os.listdir(model_dir)
        ]
        model_list_str = ' \n'.join(sorted(model_list))
        with open(self.model_fname, "w", newline="\n") as model_file:
            model_file.write(f"{len(model_list)} {self.CLASS_NUM}\n")
            # `write`, not `writelines`: this is a single string.
            model_file.write(model_list_str)
            model_file.write(' \n')

    @tasks.preprocessor
    def _generateInputFiles(self):
        """
        Generate the alignment, input and model files.
        """
        self.generateAlignmentFile()
        self.generateInputFile()
        self.generateModelFile()

    def makeCmd(self):
        """
        Return the command to run the predictor backend. The default
        implementation returns the predictor executable, the model file name,
        and the input file name.

        :rtype: list[str]
        """
        exe_path = self._getPredictorExe()
        return [exe_path, self.model_fname, self.input_fname]

    def _getPredictorExe(self):
        """
        Find the absolute path of the predictor executable. The executable is
        searched for within the `psp` build.
        """
        return os.path.join(hunt('psp'), self.EXE)

    def _getInputHeader(self):
        """
        Return the header for the input file.

        For most predictors, the header contains the number of sequences to
        predict properties for, the NU parameter, and the NY parameter. All
        the predictors in the module predict properties for only one sequence
        at a time.

        Note that this is a convenience function useful for most predictors
        but not necessarily /all/ predictors.
        """
        num_seqs = 1
        model_params = f'{num_seqs} {self.NU} {self.NY}\n'
        return model_params

    def postprocess(self):
        self._removeAlnFile()
        super().postprocess()
class SsproPredictor(AbstractPredictor):
    """
    Secondary structure predictor.
    """
    EXE = 'sspro4'
    PREDICTOR_NAME = 'sspro'
    CLASS_NUM = 3
    NU = 20
    NY = 3
    input_fname: str = 'sspro.inp'
    model_fname: str = 'sspro_model.def'

    def generateInputFile(self):
        """
        Write the input file: header, alignment file name, sequence, and a
        placeholder secondary structure line of 'H' characters (presumably a
        dummy target required by the input format — confirm).
        """
        seq = self.input.seq
        with open(self.input_fname, 'w', newline="\n") as input_file:
            header = self._getInputHeader()
            input_file.write(header)
            input_file.write(self._aln_fname + "\n")
            input_file.write(str(seq) + "\n")
            fake_ssa = 'H' * len(seq)
            input_file.write(fake_ssa + "\n")

    def makeCmd(self):
        """
        Usage:
        $PSP_PATH/sspro4 model_definition dataset_file alignment_directory
        dataset_format
        """
        # See sspro for more information on the different formats.
        # This class only supports format 0 currently.
        alignment_directory = './'
        dataset_format = '0'
        return super().makeCmd() + [alignment_directory, dataset_format]

    def _validateStdout(self):
        """
        Check that the backend output has a prediction line made only of
        characters understood by `SSA_MAP`.

        :raises RuntimeError: If the output is malformed.
        """
        stdout = self.getLogAsString()
        split_stdout = stdout.split('\n')
        if len(split_stdout) < 2:
            err_msg = ("Predictor returned incorrectly formatted output. \n"
                       "Predictor stdout:\n" + stdout)
            raise RuntimeError(err_msg)
        # Validate exactly what `rawPrediction` returns (the stripped second
        # line) so validation and parsing can't disagree.
        if any(c not in SSA_MAP for c in split_stdout[1].strip()):
            err_msg = ("Got unexpected character in output.\n"
                       "Predictor stdout:\n" + stdout)
            raise RuntimeError(err_msg)

    def rawPrediction(self):
        """
        :return: The raw prediction string (the second line of stdout,
            stripped), containing one character per residue in the input
            sequence.
        :rtype: str
        """
        stdout = self.getLogAsString()
        # Split on lines, not arbitrary whitespace, to match
        # `_validateStdout`.
        return stdout.split('\n')[1].strip()

    def prediction(self):
        """
        :return: A list of ssa types from `SSA_MAP`, one for each character
            of the raw prediction (i.e. each residue of `self.input.seq`)
        :rtype: list
        :raises RuntimeError: If the backend output is malformed.
        """
        self._validateStdout()
        ssa = [SSA_MAP[c] for c in self.rawPrediction()]
        return ssa
# Solvent accessibility classes returned by `AccproPredictor.prediction`; the
# backend's 'e' maps to EXPOSED and 'b' to BURIED via
# `AccproPredictor.CHAR_TO_ACC_MAP`.
SolventAccessibility = jsonable.JsonableEnum('SolventAccessibility',
                                             'BURIED EXPOSED')
class AccproPredictor(AbstractPredictor):
    """
    Solvent accessibility predictor.
    """
    EXE = 'accpro'
    PREDICTOR_NAME = 'accpro'
    CLASS_NUM = 20
    NU = 20
    NY = 3
    input_fname: str = 'accpro.inp'
    model_fname: str = 'accpro_model.def'
    # Raw output character -> accessibility class
    CHAR_TO_ACC_MAP = {
        'e': SolventAccessibility.EXPOSED,
        'b': SolventAccessibility.BURIED
    }

    def generateInputFile(self):
        """
        Write the input file: the standard header, then the alignment file
        name and the query sequence on their own lines.
        """
        query = self.input.seq
        with open(self.input_fname, 'w', newline="\n") as handle:
            lines = [
                self._getInputHeader(),
                self._aln_fname + "\n",
                str(query) + "\n",
            ]
            handle.writelines(lines)

    def makeCmd(self):
        """
        Usage:
        $PSP_PATH/accpro model_definition dataset_file alignment_directory
        dataset_format threshold_index
        """
        alignment_directory, dataset_format, threshold_index = './', '2', '5'
        extra = [alignment_directory, dataset_format, threshold_index]
        return super().makeCmd() + extra

    def rawPrediction(self):
        """
        Example: eeebbbebebebebbebbebebeebbbbbbbeeeee

        e = exposed
        b = buried
        """
        return self.getLogAsString().strip()

    def prediction(self):
        """
        :return: One `SolventAccessibility` value per raw prediction character
        :rtype: list
        :raises KeyError: If the output contains a character other than 'e'
            or 'b'
        """
        lookup = self.CHAR_TO_ACC_MAP
        return list(map(lookup.__getitem__, self.rawPrediction()))
# Reverse of `AccproPredictor.CHAR_TO_ACC_MAP`: accessibility class -> char
INVERSE_ACC_MAP = dict(
    zip(AccproPredictor.CHAR_TO_ACC_MAP.values(),
        AccproPredictor.CHAR_TO_ACC_MAP.keys()))


def encode_acc(acc):
    """
    Encode solvent accessibility values as an ACCpro-style string.

    :param acc: Iterable of `SolventAccessibility` values
    :return: One character ('e' or 'b') per value
    :rtype: str
    :raises KeyError: For any value not in `INVERSE_ACC_MAP` (e.g. None)
    """
    chars = [INVERSE_ACC_MAP[value] for value in acc]
    return ''.join(chars)
# TODO: Confirm what these values should be Disordered = jsonable.JsonableEnum('Disordered', 'HIGHSCORE MEDIUMSCORE LOWSCORE')
class SsAccDependentPredictors(AbstractPredictor):
    """
    Base class for predictors that use secondary structure and solvent
    accessibility predictions as inputs.

    If `input.ss_prediction` / `input.acc_prediction` are not provided, they
    are filled in during preprocessing. Predictions already stored on
    `input.seq` are used when present; otherwise the SSpro / ACCpro
    predictors are run.
    """

    class Input(parameters.CompoundParam):
        seq: sequence.ProteinSequence = None
        aln: alignment.ProteinAlignment = None
        # Raw secondary structure prediction string, one char per residue
        ss_prediction: str = None
        # Raw accessibility prediction string ('e'/'b'), one char per residue
        acc_prediction: str = None

    input = Input()

    @tasks.preprocessor(order=-1)
    def _generateSsInput(self):
        """
        Populate `input.ss_prediction` if it was not supplied.
        """
        inp = self.input  # avoid shadowing the `input` builtin
        seq = inp.seq
        if inp.ss_prediction is not None:
            return
        no_prediction = (len(seq.pred_secondary_structures) == 1 and
                         seq.pred_secondary_structures[0][1] is None)
        if no_prediction:
            pred = predict_secondary_structure(seq,
                                               inp.aln,
                                               mutate_in_place=False)
            inp.ss_prediction = pred.rawPrediction()
        else:
            ssa = (res.pred_secondary_structure for res in seq.residues())
            inp.ss_prediction = encode_ssa(ssa)

    @tasks.preprocessor(order=-1)
    def _generateAccInput(self):
        """
        Populate `input.acc_prediction` if it was not supplied.
        """
        inp = self.input  # avoid shadowing the `input` builtin
        seq = inp.seq
        if inp.acc_prediction is not None:
            return
        # Check the same residues (gaps excluded) that get encoded below.
        acc_predictions = [res.pred_accessibility for res in seq.residues()]
        if all(p is None for p in acc_predictions):
            # NOTE(review): unlike `_generateSsInput`, this does not pass
            # `mutate_in_place=False`, so the accessibility prediction is also
            # stored on `seq` — confirm this side effect is intended.
            pred = predict_solvent_accessibility(seq, inp.aln)
            inp.acc_prediction = pred.rawPrediction()
        else:
            inp.acc_prediction = encode_acc(acc_predictions)
class DisproPredictor(SsAccDependentPredictors):
    """
    Disordered regions predictor.
    """
    EXE = 'dispro'
    PREDICTOR_NAME = 'dispro'
    CLASS_NUM = 2
    NU = 25
    NY = 2
    input_fname: str = 'dispro.inp'
    model_fname: str = 'dispro_model.def'
    # Residues not marked 'N' whose probability exceeds this value are
    # classified as HIGHSCORE; the rest are MEDIUMSCORE.
    HIGH_SCORE_THRESHOLD = 0.9

    def makeCmd(self):
        """
        Return the default command followed by the alignment directory.
        """
        alignment_directory = './'
        return super().makeCmd() + [alignment_directory]

    def generateInputFile(self):
        """
        Write the input file: header, "<alignment file> <length>", sequence,
        SS prediction, accessibility prediction, and a line of 'N' characters
        (presumably a placeholder target expected by the format — confirm).
        """
        seq_str = str(self.input.seq)
        with open(self.input_fname, 'w', newline="\n") as input_file:
            header = self._getInputHeader()
            input_file.write(header)
            input_file.write(self._aln_fname + ' ' + str(len(seq_str)) +
                             "\n")
            input_file.write(seq_str + "\n")
            input_file.write(self.input.ss_prediction + '\n')
            input_file.write(self.input.acc_prediction + '\n')
            input_file.write('N' * len(seq_str) + '\n')

    def rawPrediction(self):
        """
        :return: The backend stdout split into lines
        :rtype: list[str]
        """
        stdout = self.getLogAsString()
        return stdout.split('\n')

    def prediction(self):
        """
        :return: One `Disordered` value per residue. The second stdout line
            holds a per-residue flag and the third a per-residue probability.
        :rtype: list
        """
        raw_lines = self.rawPrediction()
        disordered_pred = raw_lines[1].strip()
        probabilities = [float(p) for p in raw_lines[2].split()]
        pred = []
        for dis, prob in zip(disordered_pred, probabilities):
            if dis == 'N':
                pred.append(Disordered.LOWSCORE)
            elif prob > self.HIGH_SCORE_THRESHOLD:
                pred.append(Disordered.HIGHSCORE)
            else:
                pred.append(Disordered.MEDIUMSCORE)
        return pred
# Domain classes returned by `DomproPredictor.prediction`: Interdomain for
# residues the backend marks 'N', DomainForming for every other character.
DomainArrangement = jsonable.JsonableEnum('DomainArrangement',
                                          'Interdomain DomainForming')
class DomproPredictor(SsAccDependentPredictors):
    """
    Domain arrangement predictor.
    """
    EXE = 'dompro'
    PREDICTOR_NAME = 'dompro'
    CLASS_NUM = 2
    NU = 25
    NY = 3
    input_fname: str = 'dompro.inp'
    model_fname: str = 'dompro_model.def'

    def makeCmd(self):
        """
        Return the default command followed by the alignment directory.
        """
        alignment_directory = './'
        return super().makeCmd() + [alignment_directory]

    def generateInputFile(self):
        """
        Write the input file: header, "<alignment file> <length>", sequence,
        SS prediction, accessibility prediction, a line containing "1", and a
        line of 'N' characters.
        """
        seq_str = str(self.input.seq)
        with open(self.input_fname, 'w', newline="\n") as input_file:
            header = self._getInputHeader()
            input_file.write(header)
            input_file.write(self._aln_fname + ' ' + str(len(seq_str)) +
                             "\n")
            input_file.write(seq_str + "\n")
            input_file.write(self.input.ss_prediction + '\n')
            input_file.write(self.input.acc_prediction + '\n')
            input_file.write("1\n")
            input_file.write('N' * len(seq_str) + '\n')

    def rawPrediction(self):
        """
        :return: The prediction line (second line of stdout), stripped so
            that stray whitespace such as a trailing '\\r' is not treated as
            a residue
        :rtype: str
        """
        stdout = self.getLogAsString()
        return stdout.split('\n')[1].strip()

    def prediction(self):
        """
        :return: One `DomainArrangement` value per character of the raw
            prediction: Interdomain for 'N', DomainForming otherwise
        :rtype: list
        """
        return [
            DomainArrangement.Interdomain
            if c == 'N' else DomainArrangement.DomainForming
            for c in self.rawPrediction()
        ]
class DiproPredictor(SsAccDependentPredictors):
    """
    Disulfide bonds predictor.
    """
    EXE = 'dipro'
    PREDICTOR_NAME = 'dipro'
    CLASS_NUM = 0.5
    input_fname: str = 'dipro.inp'
    model_fname: str = 'dipro_model.def'

    class DiproFormat(enum.IntEnum):
        """
        For use with command line invocation.
        """
        Alessandro = 1
        NewDipro = 2

    def makeCmd(self):
        """
        Usage:
        $PSP_PATH/dipro model_file sequence_file alignment_file format
        """
        cmd = super().makeCmd()
        cmd.extend([self._aln_fname, str(int(self.DiproFormat.NewDipro))])
        return cmd

    def generateInputFile(self):
        """
        Write the DIpro input file.

        Unlike most predictors this does not use `_getInputHeader`. The header
        is just the number of sequences, and the alignment file name and
        sequence length go on separate lines. Buried residues ('b') in the
        accessibility prediction are written as '-'.
        """
        seq_str = str(self.input.seq)
        with open(self.input_fname, 'w', newline="\n") as input_file:
            num_sequences = 1
            input_file.write(f"{num_sequences}\n")
            input_file.write(self._aln_fname + "\n" + str(len(seq_str)) +
                             "\n")
            input_file.write(seq_str + "\n")
            input_file.write(self.input.ss_prediction + '\n')
            input_file.write(
                self.input.acc_prediction.replace('b', '-') + '\n')

    def prediction(self):
        """
        Parse the predicted disulfide bonds from the backend output.

        Bond lines are read from the line after the "Bond_Index" header up to
        the first blank line (or end of output). If no header is found,
        parsing starts from the second line of output.

        :return: A list of disulfide bonds represented by 2-tuples with two
            (0-indexed) residue indexes
        :rtype: list[tuple[int, int]]
        """
        body_index = 0
        stdout = self.getLogAsString()
        output = stdout.split('\n')
        for idx, line in enumerate(output):
            if line.startswith("Bond_Index"):
                body_index = idx
                break
        bonds = []
        for line in output[body_index + 1:]:
            fields = line.split()
            # Treat whitespace-only lines (e.g. a stray '\r') as the blank
            # terminator too, instead of failing to unpack them.
            if not fields:
                break
            _, idx_a, idx_b = fields
            # bond idxs are 1-indexed so we subtract 1 here.
            bonds.append((int(idx_a) - 1, int(idx_b) - 1))
        return bonds
class BetaproPredictor(AbstractPredictor):
    """
    Beta strand contacts predictor
    """
    EXE = 'betapro'
    PREDICTOR_NAME = 'betapro'
    CLASS_NUM = ''
    NU = 20
    NY = 3
    input_fname: str = 'betapro.inp'
    model_fname: str = 'betapro_model.def'

    class Input(parameters.CompoundParam):
        seq: sequence.ProteinSequence = None
        aln: alignment.ProteinAlignment = None
        # Secondary structure prediction string; must be set by the caller
        ss_prediction: str

    input = Input()

    def generateInputFile(self):
        """
        Write the input file: the standard header, then the alignment file
        name, sequence length, sequence and SS prediction, one per line.
        """
        residues = str(self.input.seq)
        with open(self.input_fname, 'w', newline="\n") as handle:
            handle.write(self._getInputHeader())
            fields = (self._aln_fname, str(len(residues)), residues,
                      self.input.ss_prediction)
            for field in fields:
                handle.write(field + "\n")

    def makeCmd(self):
        """
        Usage:
        $PSP_PATH/betapro model_file, protein_file, alignment_file
        """
        return super().makeCmd() + [self._aln_fname]

    def prediction(self):
        """
        :return: The unprocessed backend stdout
        :rtype: str
        """
        # For now, we just return the stdout. In the future, we'll probably
        # have to do some processing depending on how we want to present
        # this data.
        raw_output = self.getLogAsString()
        return raw_output
class PredictorWrapperTask(tasks.BlockingFunctionTask):
    """
    Task that runs the prediction function registered for an annotation type
    in `PRED_ANNO_TO_PRED_FUNC`.
    """

    def __init__(self, anno, seq, blast_ann):
        """
        :param anno: Annotation type; must be a key of
            `PRED_ANNO_TO_PRED_FUNC` (KeyError otherwise)
        :param seq: Sequence to run the prediction on
        :param blast_ann: Alignment handed to the prediction function as its
            `aln` argument
        """
        super().__init__()
        pred_func = PRED_ANNO_TO_PRED_FUNC[anno]
        self._pred_func = pred_func
        self._seq, self._blast_ann = seq, blast_ann

    def mainFunction(self):
        """
        Run the prediction. The predictor object it returns is discarded.
        """
        run, seq, aln = self._pred_func, self._seq, self._blast_ann
        run(seq, aln)
def _run_prediction(pred, seq, aln):
    """
    Run `pred` on `seq` and `aln`, wait for it, and return its prediction.

    :param pred: Predictor task to run
    :param seq: Sequence assigned to `pred.input.seq`
    :param aln: Alignment assigned to `pred.input.aln`
    :return: Whatever `pred.prediction()` returns
    :raises Exception: The exception stored in `pred.failure_info` if the task
        failed; the failure info is printed first.
    """
    task_input = pred.input
    task_input.seq, task_input.aln = seq, aln
    pred.start()
    pred.wait()
    if pred.status is not pred.FAILED:
        return pred.prediction()
    print(pred.failure_info)
    raise pred.failure_info.exception
def predict_secondary_structure(seq, aln, mutate_in_place=True):
    """
    Run the SSpro secondary structure predictor on `seq`.

    :param seq: Sequence to predict for
    :param aln: Alignment used as predictor input
    :param mutate_in_place: Whether to store the predictions on `seq` via
        `setSSAPredictions`
    :return: The finished predictor; `rawPrediction()` gives the raw string
    :rtype: SsproPredictor
    """
    predictor = SsproPredictor()
    predictions = _run_prediction(predictor, seq, aln)
    if not mutate_in_place:
        return predictor
    seq.setSSAPredictions(predictions)
    return predictor
def predict_solvent_accessibility(seq, aln, mutate_in_place=True):
    """
    Run the ACCpro solvent accessibility predictor on `seq`.

    :param seq: Sequence to predict for
    :param aln: Alignment used as predictor input
    :param mutate_in_place: Whether to store the predictions on `seq` via
        `setSolventAccessibilityPredictions`
    :return: The finished predictor; `rawPrediction()` gives the raw string
    :rtype: AccproPredictor
    """
    predictor = AccproPredictor()
    acc_values = _run_prediction(predictor, seq, aln)
    if not mutate_in_place:
        return predictor
    seq.setSolventAccessibilityPredictions(acc_values)
    return predictor
def predict_disordered_regions(seq, aln, mutate_in_place=True):
    """
    Run the DISpro disordered region predictor on `seq`.

    :param seq: Sequence to predict for
    :param aln: Alignment used as predictor input
    :param mutate_in_place: Whether to store the predictions on `seq` via
        `setDisorderedRegionsPredictions`
    :return: The finished predictor
    :rtype: DisproPredictor
    """
    predictor = DisproPredictor()
    disorder_values = _run_prediction(predictor, seq, aln)
    if not mutate_in_place:
        return predictor
    seq.setDisorderedRegionsPredictions(disorder_values)
    return predictor
def predict_domain_arrangement(seq, aln, mutate_in_place=True):
    """
    Run the DOMpro domain arrangement predictor on `seq`.

    :param seq: Sequence to predict for
    :param aln: Alignment used as predictor input
    :param mutate_in_place: Whether to store the predictions on `seq` via
        `setDomainArrangementPredictions`
    :return: The finished predictor
    :rtype: DomproPredictor
    """
    predictor = DomproPredictor()
    domain_values = _run_prediction(predictor, seq, aln)
    if not mutate_in_place:
        return predictor
    seq.setDomainArrangementPredictions(domain_values)
    return predictor
def predict_disulfide_bond(seq, aln, mutate_in_place=True):
    """
    Run the DIpro disulfide bond predictor on `seq`.

    :param seq: Sequence to predict for
    :param aln: Alignment used as predictor input
    :param mutate_in_place: If True, remove the disulfide bonds currently in
        `seq.pred_disulfide_bonds`, add the newly predicted bonds (with
        ``known=False``), and emit `seq.predictionsChanged`
    :return: The finished predictor
    :rtype: DiproPredictor
    """
    pred = DiproPredictor()
    disulfide_predictions = _run_prediction(pred, seq, aln)
    if mutate_in_place:
        # Iterate over a snapshot: removing a bond may mutate the collection
        # returned by `pred_disulfide_bonds`. Mutating a collection while
        # iterating over it skips items or raises.
        for bond in list(seq.pred_disulfide_bonds):
            residue.remove_disulfide_bond(bond)
        for idx1, idx2 in disulfide_predictions:
            residue.add_disulfide_bond(seq[idx1], seq[idx2], known=False)
        seq.predictionsChanged.emit()
    return pred
# Predicted-annotation type -> function that computes it (used by
# `PredictorWrapperTask`).
_PRED_FUNC_PAIRS = (
    (SEQ_ANNO_TYPES.pred_secondary_structure, predict_secondary_structure),
    (SEQ_ANNO_TYPES.pred_accessibility, predict_solvent_accessibility),
    (SEQ_ANNO_TYPES.pred_domain_arr, predict_domain_arrangement),
    (SEQ_ANNO_TYPES.pred_disordered, predict_disordered_regions),
    (SEQ_ANNO_TYPES.pred_disulfide_bonds, predict_disulfide_bond),
)
PRED_ANNO_TO_PRED_FUNC = {anno: func for anno, func in _PRED_FUNC_PAIRS}
# Reverse of `SSA_MAP`: secondary structure value -> predictor character
INVERSE_SSA_MAP = dict(zip(SSA_MAP.values(), SSA_MAP.keys()))
def encode_ssa(ssa):
    """
    Encode secondary structure values as a predictor-style string.

    :param ssa: Iterable of secondary structure values (keys of
        `INVERSE_SSA_MAP`)
    :return: One character per value
    :rtype: str
    :raises KeyError: For any value missing from `INVERSE_SSA_MAP`
    """
    return ''.join(map(INVERSE_SSA_MAP.__getitem__, ssa))