# Source code for schrodinger.application.steps.filters

"""
Filter steps that take input Mols and only output the Mols that meet
certain criteria.

NOTE::
    `UniqueSmilesFilter` is a chain step that picks its implementation when
    the chain is built, based on its `use_cloud_services` setting: when it
    is False (the default) the in-memory `InMemoryUniqueSmilesFilter` is
    used; when it is True, BigQuery is used if the SCHRODINGER_GCP_KEY
    environment variable is defined, otherwise Redshift is used if AWS
    services are available (a RuntimeError is raised if neither is set up).

    You can generally use `UniqueSmilesFilter` for your workflows and users
    will enable cloud services as needed when they scale higher, but
    if you know a part of your workflow will be low-volume then it would
    be better to use `InMemoryUniqueSmilesFilter` directly instead.

    All of the above also applies to `RandomSampleFilter` and
    `UniqueRandomSampleFilter`, which also have in-memory and cloud
    variants.

"""
import copy
import os
from random import Random
from sys import maxsize
from typing import Dict
from typing import List

from rdkit import Chem
from rdkit.Chem import Descriptors
from rdkit.Chem import rdFMCS

from schrodinger import stepper
from schrodinger.application.pathfinder import filtering
from schrodinger.models import parameters
from schrodinger.structutils import filter
from schrodinger.structutils.smiles import STEREO_FROM_ANNOTATION_AND_GEOM
from schrodinger.structutils.smiles import SmilesGenerator

from . import bigquery_deduplication
from . import env_keys
from . import redshift_deduplication
from . import scorers
from . import utils
from .basesteps import LoggerMixin
from .basesteps import MaeMapStep
from .basesteps import MolMapStep
from .basesteps import MolMolWorkflow
from .basesteps import MolReduceStep
from .dataclasses import ScoredMol
from .dataclasses import ScoredMolSerializer
from .dataclasses import ScoredSmiles
from .dataclasses import ScoredSmilesIOMixin

INF = float('inf')

# FIXME: FILTER_TABLE will be removed in AD-480 and replaced with a single
# param.
# Table settings class used by the cloud-capable filters in this module:
# Redshift tables when AWS services are available, BigQuery tables otherwise.
# NOTE(review): the cloud filters' buildChain() methods check
# SCHRODINGER_GCP_KEY *before* AWS availability, while this selection prefers
# AWS; if both are configured the table type and the chosen backend may
# disagree -- confirm the intended precedence.
if env_keys.is_aws_service_available():
    FILTER_TABLE = redshift_deduplication.RSTable
else:
    FILTER_TABLE = bigquery_deduplication.BQTable


class InMemoryUniqueRandomSampleFilter(stepper.Chain):
    """
    In-memory chain that first drops molecules with an already-seen
    canonical SMILES and then takes a random sample of size `n` from the
    remaining molecules.
    """

    class Settings(parameters.CompoundParam):
        n: int = None  # sample size
        seed: int = None  # seed for the random sampler

    def buildChain(self):
        self.addStep(InMemoryUniqueSmilesFilter())
        sample_size, seed = self.settings.n, self.settings.seed
        self.addStep(InMemoryRandomSampleFilter(n=sample_size, seed=seed))
class InMemoryUniqueSmilesFilter(MolReduceStep):
    """
    Pass through only molecules whose canonical SMILES has not appeared
    earlier in the input. All seen SMILES are kept in an in-memory set.
    """

    def reduceFunction(self, mols):
        encountered = set()
        for candidate in mols:
            canonical = Chem.MolToSmiles(candidate)
            if canonical in encountered:
                continue
            encountered.add(canonical)
            yield candidate
class InMemoryRandomSampleFilter(MolReduceStep):
    """
    Take a uniform random subsample of at most `settings.n` molecules.

    This is Algorithm R (reservoir sampling), which does not need to know the
    length of the input in advance. See
    https://en.wikipedia.org/wiki/Reservoir_sampling

    The reservoir stores SMILES strings; the sampled molecules are re-parsed
    from those SMILES on output.
    """

    class Settings(parameters.CompoundParam):
        n: int = None  # sample size; must be set and positive
        seed: int = None  # seed for the random generator

    def validateSettings(self):
        sample_size = self.settings.n
        if sample_size is None:
            return [stepper.SettingsError(self, 'n must be set')]
        if sample_size <= 0:
            return [stepper.SettingsError(self, 'n must be larger than 0')]
        return []

    def reduceFunction(self, mols):
        rng = Random(self.settings.seed)
        capacity = self.settings.n
        # One shared iterator, so that the fill phase and the replacement
        # phase together consume each molecule exactly once.
        stream = iter(mols)
        reservoir = []
        # Fill phase: the first `capacity` molecules are always kept.
        for mol in stream:
            reservoir.append(Chem.MolToSmiles(mol))
            if len(reservoir) == capacity:
                break
        # Replacement phase: the molecule at 0-based position `position`
        # lands in the reservoir with probability capacity / (position + 1)
        # (randint is inclusive of both ends).
        for position, mol in enumerate(stream, start=capacity):
            slot = rng.randint(0, position)
            if slot < capacity:
                reservoir[slot] = Chem.MolToSmiles(mol)
        for smi in reservoir:
            yield Chem.MolFromSmiles(smi)
class CloudOptionMixin:
    """
    Mixin letting a step choose between cloud-based and in-memory processing
    through the `use_cloud_services` setting.
    """

    class Settings(parameters.CompoundParam):
        use_cloud_services: bool = False

    def validateSettings(self):
        issues = super().validateSettings()
        wants_cloud = self.settings.use_cloud_services
        # Only an error when cloud services are requested but unavailable.
        if not wants_cloud or env_keys.has_cloud_services():
            return issues
        return issues + [
            stepper.SettingsError(
                self, 'use_cloud_services is set to True but no cloud '
                'environment variable key is defined!')
        ]
class UniqueRandomSampleFilter(CloudOptionMixin, stepper.Chain):
    """
    Chain that deduplicates molecules by SMILES and then randomly samples
    `n` of them.

    When `use_cloud_services` is False the in-memory chain is used.
    Otherwise the BigQuery step is used if SCHRODINGER_GCP_KEY is set, else
    the Redshift step if AWS services are available; if neither is
    configured a RuntimeError is raised.
    """

    class Settings(CloudOptionMixin.Settings):
        table: FILTER_TABLE
        n: int = None
        seed: int = None  # only passed to the in-memory variant

    def buildChain(self):
        settings = self.settings
        if not settings.use_cloud_services:
            self.addStep(
                InMemoryUniqueRandomSampleFilter(n=settings.n,
                                                 seed=settings.seed))
            return
        if env_keys.SCHRODINGER_GCP_KEY:
            step_class = bigquery_deduplication.BQDedupeAndRandomSampleFilter
        elif env_keys.is_aws_service_available():
            step_class = (
                redshift_deduplication.RSDeduplicateAndRandomSampleFilter)
        else:
            raise RuntimeError('Missing cloud service environment keys')
        self.addStep(
            step_class(table=copy.deepcopy(settings.table), n=settings.n))
class UniqueSmilesFilter(CloudOptionMixin, stepper.Chain):
    """
    Chain that removes molecules with duplicate SMILES.

    When `use_cloud_services` is False an in-memory filter is used.
    Otherwise the BigQuery filter is used if SCHRODINGER_GCP_KEY is set, else
    the Redshift filter if AWS services are available; if neither is
    configured a RuntimeError is raised.
    """

    class Settings(CloudOptionMixin.Settings):
        table: FILTER_TABLE

    def buildChain(self):
        if not self.settings.use_cloud_services:
            self.addStep(InMemoryUniqueSmilesFilter())
            return
        table = self.settings.table
        if env_keys.SCHRODINGER_GCP_KEY:
            # The BigQuery filter receives the table fields as keywords.
            step = bigquery_deduplication.BQUniqueSmilesFilter(
                **table.toDict())
        elif env_keys.is_aws_service_available():
            step = redshift_deduplication.RSUniqueSmilesFilter(
                table=copy.deepcopy(table))
        else:
            raise RuntimeError('Missing cloud service environment keys')
        self.addStep(step)
class RandomSampleFilter(CloudOptionMixin, stepper.Chain):
    """
    Chain that randomly samples `n` molecules.

    When `use_cloud_services` is False the in-memory reservoir sampler is
    used (with `seed`). Otherwise the BigQuery sampler is used if
    SCHRODINGER_GCP_KEY is set, else the Redshift sampler if AWS services
    are available; if neither is configured a RuntimeError is raised.
    """

    class Settings(CloudOptionMixin.Settings):
        table: FILTER_TABLE
        n: int = None
        seed: int = None  # only passed to the in-memory sampler

    def buildChain(self):
        settings = self.settings
        if not settings.use_cloud_services:
            self.addStep(
                InMemoryRandomSampleFilter(n=settings.n, seed=settings.seed))
            return
        if env_keys.SCHRODINGER_GCP_KEY:
            sampler_class = bigquery_deduplication.BQRandomSampleFilter
        elif env_keys.is_aws_service_available():
            sampler_class = redshift_deduplication.RSRandomSampleFilter
        else:
            raise RuntimeError('Missing cloud service environment keys')
        self.addStep(
            sampler_class(table=copy.deepcopy(settings.table), n=settings.n))
class MaeUniqueSmilesFilter(MaeMapStep):
    """
    Pass through only structures whose unique SMILES has not already been
    produced by this step.
    """

    def setUp(self):
        super().setUp()
        self._seen_smiles = set()
        self._smiles_generator = SmilesGenerator(
            STEREO_FROM_ANNOTATION_AND_GEOM, unique=True)

    def mapFunction(self, st):
        smiles = self._smiles_generator.getSmiles(st)
        if smiles in self._seen_smiles:
            return
        self._seen_smiles.add(smiles)
        yield st
class SmartsFilter(MolMapStep):
    """
    Pass through only molecules that contain the substructure given by the
    `core_smarts` setting.
    """

    class Settings(parameters.CompoundParam):
        core_smarts: str = None

    def validateSettings(self):
        return utils.validate_core_smarts(self, self.settings.core_smarts)

    def setUp(self):
        super().setUp()
        self._core_smarts = Chem.MolFromSmarts(self.settings.core_smarts)
        # Whether explicit hydrogens are needed for matching; determined
        # lazily in mapFunction.
        self._need_Hs = None

    def mapFunction(self, mol):
        if self._need_Hs is None:
            self._need_Hs = utils.need_Hs(mol, self._core_smarts)
            if self._need_Hs is None:
                # wasn't able to match; the decision is retried on the next
                # molecule
                return
        target = Chem.AddHs(mol) if self._need_Hs else mol
        if target.HasSubstructMatch(self._core_smarts):
            yield mol
class ChiralCenterCountFilter(MolMapStep):
    """
    Pass through only molecules whose number of chiral centers (assigned or
    unassigned) lies within the inclusive range
    [settings.min_value, settings.max_value].
    """

    class Settings(parameters.CompoundParam):
        min_value: int = 0
        max_value: int = maxsize

    def validateSettings(self):
        if self.settings.min_value > self.settings.max_value:
            return [stepper.SettingsError(self, 'min_value > max_value')]
        return []

    def mapFunction(self, mol):
        centers = Chem.FindMolChiralCenters(mol, includeUnassigned=True)
        low, high = self.settings.min_value, self.settings.max_value
        if low <= len(centers) <= high:
            yield mol
class ProductFilterMixin:
    """
    Mixin that passes only molecules satisfying every SMARTS
    substructure-count product filter read from `settings.cflt_file`.

    `filter.SmartsFilter.checkStructure` accepts either `Mol` or `Structure`
    objects, so the input type is decided by the class this is mixed into,
    i.e. a `MolMapStep` or a `MaeMapStep`.
    """

    class Settings(parameters.CompoundParam):
        """
        :ivar cflt_file: path to the product filter file from which the
            SMARTS filter is built
        """
        cflt_file: stepper.StepperFile  # the product filter file

    def _setSmartsFilter(self):
        # Reset first; only build a filter when the file actually exists.
        self._smarts_filter = None
        path = self.settings.cflt_file
        if path and os.path.isfile(path):
            self._smarts_filter = filter.SmartsFilter(filename=path)

    def validateSettings(self):
        issues = utils.validate_file(self, 'cflt_file', required=True)
        if issues:
            return issues
        self._setSmartsFilter()
        smarts_filter = self._smarts_filter
        if smarts_filter is None or len(smarts_filter.filters) == 0:
            issues.append(
                stepper.SettingsError(self, 'no product filters defined'))
        return issues

    def setUp(self):
        super().setUp()
        self._setSmartsFilter()
        self.logger.info(
            f'{self.getStepId()} {len(self._smarts_filter.filters)} filters')

    def mapFunction(self, molecule):
        passes = self._smarts_filter.checkStructure(molecule)
        if passes:
            yield molecule
class ProductFilter(ProductFilterMixin, MolMapStep):
    """
    `Chem.Mol` variant of the product filter; see `ProductFilterMixin`.
    """
class MaeProductFilter(ProductFilterMixin, MaeMapStep):
    """
    Structure (Maestro) variant of the product filter; see
    `ProductFilterMixin`.
    """
class ScoredSmilesProductFilter(ScoredSmilesIOMixin, ProductFilterMixin,
                                LoggerMixin, stepper.MapStep):
    """
    Product filter whose inputs and outputs are `ScoredSmiles`. Each SMILES
    is parsed into a `Chem.Mol` and checked with the `ProductFilterMixin`
    filters; unparsable SMILES are logged as errors and dropped.
    See `ProductFilterMixin`.
    """

    def mapFunction(self, scored_smiles):
        mol = Chem.MolFromSmiles(scored_smiles.smiles)
        if not mol:
            self.logger.error(f'ERROR: {self.getStepId()} invalid SMILES:'
                              f' {scored_smiles.smiles}')
            return
        survivors = list(super().mapFunction(mol))
        if survivors:
            yield scored_smiles
class PropertyFilter(MolMapStep):
    """
    Pass through only molecules whose properties satisfy every property
    filter read from `settings.filter_file`.

    A filter file with the .json extension is interpreted as a JSON filter
    file of the kind used by a few panels, including PathFinder; that format
    is implemented by `schrodinger.ui.qt.filter_dialog_dir.filter_core`.
    Any other file is read as plain text in the PathFinder/Canvas filter
    format, with lines such as::

        r_rdkit_TPSA > 30.0 < 150.0
        i_rdkit_NumRotatableBonds < 10
        i_rdkit_NumChiralCenters == 1
        r_rdkit_MolWt > 150.0 < 575.0
    """

    class Settings(parameters.CompoundParam):
        filter_file: stepper.StepperFile

    def _setPropertyFilter(self):
        # Reset first; only build a filter when the file actually exists.
        self._prop_filter = None
        path = self.settings.filter_file
        if path and os.path.isfile(path):
            self._prop_filter = filtering.get_filter(path, smarts_filter=False)

    def _getCheckCount(self):
        """
        :return: the number of active criteria (checked criteria for JSON
            filters, property names otherwise)
        :rtype: int
        """
        prop_filter = self._prop_filter
        if not isinstance(prop_filter, filtering.JSONFilterAdapter):
            return len(prop_filter.getPropertyNames())
        return sum(1 for criterion in prop_filter.filter_obj.criteria
                   if criterion.checked)

    def validateSettings(self):
        issues = utils.validate_file(self, 'filter_file', required=True)
        if issues:
            return issues
        self._setPropertyFilter()
        has_filters = (self._prop_filter is not None and
                       self._getCheckCount() >= 1)
        if not has_filters:
            issues.append(
                stepper.SettingsError(self, 'no property filters defined'))
        return issues

    def setUp(self):
        super().setUp()
        self._setPropertyFilter()
        self._descriptor_names = self._prop_filter.getPropertyNames()
        self.logger.info(f'{self.getStepId()}: {self._getCheckCount()} filters')

    def mapFunction(self, mol):
        # Compute the descriptors the filter needs (presumably stored on the
        # mol for the filter to read -- see filtering.add_descriptors).
        filtering.add_descriptors(mol, self._descriptor_names, refs=None)
        yield from self._prop_filter.filter([mol])
class ProfileSettings(parameters.CompoundParam):
    """
    Settings describing an allowed property profile.

    `property_ranges` maps a property name (which must be a key of
    `filtering.DESCRIPTORS_DICT`) to a ``[min, max]`` pair giving the values
    allowed for that property; both ends are included in what is considered
    to be acceptable.
    """
    property_ranges: Dict[str, List[float]]

    def validate(self, step):
        """
        Validate the settings on behalf of a step.

        :param step: the step owning these settings; used as the source of
            any reported issue
        :type step: stepper._BaseStep
        :return: the issues found; a single warning if no property ranges
            are defined
        :rtype: list[stepper.SettingsError or stepper.SettingsWarning]
        """
        if not self.property_ranges:
            return [stepper.SettingsWarning(step, 'no filtering will be done')]
        issues = []
        # `prop_range` rather than `range` so the builtin is not shadowed.
        for prop, prop_range in self.property_ranges.items():
            if prop not in filtering.DESCRIPTORS_DICT:
                issues.append(
                    stepper.SettingsError(step, f'{prop} is not a valid property'))
            if len(prop_range) != 2:
                # Consumers index only prop_range[0] and prop_range[1], so a
                # range of any other length would raise IndexError or
                # silently ignore values.
                issues.append(
                    stepper.SettingsError(
                        step, f'{prop} range should have exactly 2 values'))
            elif prop_range[0] > prop_range[1]:
                issues.append(
                    stepper.SettingsError(step, f'{prop} has an invalid range'))
        return issues
class ProfileFilter(MolMapStep):
    """
    A product filter where the property profile is defined by the settings.

    Every property in `property_ranges` must lie within its inclusive
    ``[min, max]`` range, or the molecule is rejected. The `core_complexity`
    is a dictionary with the property name as key and the value above which
    the complexity is incremented (by one per property). If the total
    complexity exceeds `max_complexity` the molecule is rejected.

    Example in yaml notation::

        ProfileFilter:
            property_ranges:
                MolWt: [250, 500]
                FractionCSP3: [0, 1]
                AlogP: [-1, 4]
                NumRotatableBonds: &NumRotatableBonds_range [0, 10]

    :see: ProfileSettings
    """

    class Settings(ProfileSettings):
        core_complexity: Dict[str, int]
        max_complexity: int = 1
        large_ring_cutoff: int = 6

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # Register a 'LargeRingCount' descriptor. NOTE: DESCRIPTORS_DICT
            # is module-global, so the cutoff used is that of the most
            # recently constructed Settings instance.
            filtering.DESCRIPTORS_DICT[
                'LargeRingCount'] = lambda x: self.largeRingCount(x)

        def largeRingCount(self, mol):
            """
            :param mol: the molecule to inspect
            :type mol: Chem.Mol
            :return: the number of rings with more than `large_ring_cutoff`
                atoms
            :rtype: int
            """
            return sum(1 for ring in mol.GetRingInfo().AtomRings()
                       if len(ring) > self.large_ring_cutoff)

        def validate(self, step):
            """
            Validate the settings on behalf of a step, adding the complexity
            checks to those of `ProfileSettings.validate`.

            :param step: the step owning these settings
            :type step: stepper._BaseStep
            :rtype: list[stepper.SettingsError or stepper.SettingsWarning]
            """
            issues = super().validate(step)
            for prop in self.core_complexity:
                if prop not in filtering.DESCRIPTORS_DICT:
                    issues.append(
                        stepper.SettingsError(
                            step,
                            f'{prop} is not a valid core_complexity property'))
            if any(v < 0 for v in self.core_complexity.values()):
                issues.append(
                    stepper.SettingsError(
                        step, 'core_complexity values should be >= 0'))
            if self.large_ring_cutoff <= 2:
                issues.append(
                    stepper.SettingsWarning(
                        step, f'large_ring_cutoff of {self.large_ring_cutoff} is small'
                    ))
            if self.max_complexity < 0:
                issues.append(
                    stepper.SettingsError(step, 'max_complexity should be >= 0'))
            return issues

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._remainder = []

    def validateSettings(self):
        return self.settings.validate(self)

    def setUp(self):
        # Initialize the base step like every other step's setUp does.
        super().setUp()
        # the complexity properties that are not in the property_ranges dict;
        # mapFunction must compute these separately
        self._remainder = [
            prop for prop in self.settings.core_complexity
            if prop not in self.settings.property_ranges
        ]

    def _getComplexityValue(self, prop, prop_value):
        """
        :return: 1 if `prop_value` exceeds the core complexity threshold for
            `prop` (no threshold means never), else 0
        :rtype: int
        """
        if prop_value > self.settings.core_complexity.get(prop, INF):
            return 1
        return 0

    def mapFunction(self, mol):
        complexity = 0
        # check the required property ranges keeping track of complexity
        for prop, prop_range in self.settings.property_ranges.items():
            prop_value = filtering.DESCRIPTORS_DICT[prop](mol)
            if prop_value < prop_range[0] or prop_value > prop_range[1]:
                return
            complexity += self._getComplexityValue(prop, prop_value)
            if complexity > self.settings.max_complexity:
                return
        # get the core complexity ones we may not have computed yet
        for prop in self._remainder:
            prop_value = filtering.DESCRIPTORS_DICT[prop](mol)
            complexity += self._getComplexityValue(prop, prop_value)
            if complexity > self.settings.max_complexity:
                return
        yield mol
class FepFilter(MolMapStep):
    """
    Pass through only molecules that are amenable to FEP calculations.

    A molecule is amenable if it is an acceptable perturbation of at least
    `settings.min_edges` of the reference molecules. A perturbation is
    acceptable if the number of heavy atoms of the molecule outside the
    maximum common substructure (MCS) with the reference is at most
    `settings.max_hac_diff`.

    Settings:

    - fep_references_file: title-less SMILES file with the SMILES in the
      0th column (the way that Maestro exports them...)
    - ref_mols: the list of references (Chem.Mol objects)
    - min_edges: the minimum number of edges needed
    - max_hac_diff: the maximum number of heavy atoms not part of the MCS

    If `fep_references_file` names an existing file, `ref_mols` is ignored.
    """

    class Settings(parameters.CompoundParam):
        fep_references_file: stepper.StepperFile
        ref_mols: List[Chem.Mol]
        min_edges: int = 1
        max_hac_diff: int = 10

    def _getRefMols(self):
        path = self.settings.fep_references_file
        if not (path and os.path.isfile(path)):
            return self.settings.ref_mols
        supplier = Chem.SmilesMolSupplier(path, titleLine=0)
        # drop entries that failed to parse
        return [m for m in supplier if m]

    def validateSettings(self):
        issues = utils.validate_file(
            self, 'fep_references_file')  # optional StepperFile
        references = self._getRefMols()
        if not references:
            issues.append(
                stepper.SettingsError(self, 'no reference molecules defined'))
        elif len(references) < self.settings.min_edges:
            issues.append(
                stepper.SettingsError(self, 'not enough references provided'))
        return issues

    def setUp(self):
        super().setUp()
        self._ref_mols = self._getRefMols()

    def mapFunction(self, mol):
        heavy_atoms = mol.GetNumHeavyAtoms()
        max_diff = self.settings.max_hac_diff
        needed_edges = self.settings.min_edges
        edges = 0
        for ref in self._ref_mols:
            mcs = rdFMCS.FindMCS([ref, mol],
                                 ringMatchesRingOnly=True,
                                 completeRingsOnly=False,
                                 timeout=1)
            shared = Chem.MolFromSmarts(mcs.smartsString).GetNumHeavyAtoms()
            if heavy_atoms - shared > max_diff:
                continue
            edges += 1
            if edges >= needed_edges:
                yield mol
                return
class RangeFilter(MolMapStep):
    """
    Yield the `Mol` of each incoming `ScoredMol` whose score lies within the
    inclusive range [settings.min_value, settings.max_value].
    """
    Input = ScoredMol
    InputSerializer = ScoredMolSerializer

    class Settings(parameters.CompoundParam):
        min_value: float = -INF
        max_value: float = INF

    def validateSettings(self):
        low, high = self.settings.min_value, self.settings.max_value
        issues = []
        if low > high:
            issues.append(
                stepper.SettingsError(self, 'min_value is larger than max_value'))
        if low == -INF and high == INF:
            issues.append(
                stepper.SettingsWarning(self, 'no filtering will be done'))
        return issues

    def mapFunction(self, scored_mol):
        score = scored_mol.score
        if self.settings.min_value <= score <= self.settings.max_value:
            yield scored_mol.mol
class ScoreFilter(MolMolWorkflow):
    """
    Base class for a step made of a `self.SCORER_CLASS` instance followed by
    a `RangeFilter`. The `score_filter` settings dict is applied to both the
    scorer and the range filter.

    :cvar SCORER_CLASS: the class that will compute the value to be filtered
    :vartype SCORER_CLASS: ScorerStep
    """
    SCORER_CLASS = NotImplementedError

    class Settings(parameters.CompoundParam):
        score_filter: dict

    def __init__(self, scorer_class=None, **kwargs):
        if scorer_class:
            # instance-level override of the class default
            self.SCORER_CLASS = scorer_class
        super().__init__(**kwargs)

    def buildChain(self):
        config = self.settings.score_filter
        scorer = self.SCORER_CLASS()
        utils.apply_config_settings_to_step(config, scorer)
        self.addStep(scorer)
        # named range_filter to avoid shadowing the imported `filter` module
        range_filter = RangeFilter()
        utils.apply_config_settings_to_step(config, range_filter)
        self.addStep(range_filter)
class ScoreFilterChain(MolMolWorkflow):
    """
    Base class for chains made of one `ScoreFilter` per entry of the
    `score_filters` setting, each using `SCORER_CLASS` as its scorer.
    """

    class Settings(parameters.CompoundParam):
        score_filters: List[Dict]

    SCORER_CLASS = NotImplementedError

    def validateSettings(self):
        if len(self.settings.score_filters):
            warnings = []
        else:
            warnings = [
                stepper.SettingsWarning(self, 'no filtering will be done')
            ]
        return warnings + super().validateSettings()

    def buildChain(self):
        scorer_class = self.SCORER_CLASS
        for config in self.settings.score_filters:
            self.addStep(ScoreFilter(scorer_class, score_filter=config))
class LigandMLScoreFilterChain(ScoreFilterChain):
    """
    A variable-length (possibly empty) chain of ligand ML score filters.

    Each entry of `score_filters` combines `LigandMLScorer` settings with
    those of a `RangeFilter`, one entry per ligand ML model to use.

    Example yaml file configuration::

        LigandMLScoreFilterChain:
            score_filters:
                - ml_file: model1.qzip
                  min_value: -3.0
                  max_value: 2.0
                - ml_file: model2.qzip
                  max_value: 200
    """
    SCORER_CLASS = scorers.LigandMLScorer
class MaxMolWtSettings(parameters.CompoundParam):
    """
    Compound param holding an optional maximum molecular weight,
    `max_mol_wt`; None means that no limit is applied.
    """
    max_mol_wt: float = None
class MaxMolWtMixin:
    """
    A mixin for steps that filter Chem.Mol objects by molecular weight.

    The step's settings must be an instance of `MaxMolWtSettings` (or a
    subclass), which provides the `max_mol_wt` setting.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        msg = 'The Settings should inherit from MaxMolWtSettings'
        assert isinstance(self.settings, MaxMolWtSettings), msg

    def validateMaxMolWtSettings(self):
        """
        :return: A settings error if the max mol. weight setting is not
            properly set, None otherwise.
        :rtype: stepper.SettingsError or None
        """
        max_mol_wt = self.settings.max_mol_wt
        if max_mol_wt is not None and max_mol_wt <= 0.0:
            return stepper.SettingsError(
                self, 'max_mol_wt should be greater than 0.0')
        return None

    def hasAcceptableMolWt(self, mol):
        """
        Check if the specified mol passes molecular weight checks.

        :param mol: Mol to be checked.
        :type mol: Chem.Mol

        :return: True if no maximum is set or the mol's weight is at most
            the maximum, False otherwise.
        :rtype: bool
        """
        max_mw = self.settings.max_mol_wt
        if max_mw is None:
            return True
        return Descriptors.MolWt(mol) <= max_mw
class ScoredUniqueSmilesFilter(ScoredSmilesIOMixin, stepper.ReduceStep):
    """
    Yield one `ScoredSmiles` per distinct SMILES. When a SMILES occurs more
    than once, `settings.keep_lowest` decides whether the lowest (True) or
    the highest (False) score is retained.
    """

    class Settings(parameters.CompoundParam):
        keep_lowest: bool = True

    def reduceFunction(self, inputs):
        keep_lowest = self.settings.keep_lowest
        best = dict()
        for item in inputs:
            smiles, score = item.smiles, item.score
            current = best.get(smiles)
            if current is None:
                best[smiles] = score
                continue
            improves = score < current if keep_lowest else score > current
            if improves:
                best[smiles] = score
        for smiles, score in best.items():
            yield ScoredSmiles(smiles=smiles, score=score)
class SortedScoreFilter(ScoredSmilesIOMixin, stepper.ReduceStep):
    """
    Sort the scored SMILES by score (descending if `reverse` is set) and
    yield the first `max_out` of them; all of them if `max_out` is None.
    """

    class Settings(parameters.CompoundParam):
        reverse: bool = False
        max_out: int = None

    def reduceFunction(self, inputs):

        def by_score(item):
            return item.score

        ranked = sorted(inputs, key=by_score, reverse=self.settings.reverse)
        limit = self.settings.max_out
        yield from ranked[:limit]