"""
Filter steps that take input Mols and only output Mols that meet certain
criteria.
NOTE::
You can import `UniqueSmilesFilter` from this module. It builds either an
`InMemoryUniqueSmilesFilter` or a cloud-backed deduplication step (BigQuery
or Redshift), depending on its `use_cloud_services` setting and on which
cloud environment keys (e.g. SCHRODINGER_GCP_KEY) are defined.
You can generally use `UniqueSmilesFilter` for your workflows, and users
will enable cloud services as needed when they scale higher, but
if you know a part of your workflow will be low-volume then it would
be better to use `InMemoryUniqueSmilesFilter` instead.
All of the above also applies to `RandomSampleFilter`, which also has
in-memory and cloud-backed variants.
"""
import copy
import os
from random import Random
from sys import maxsize
from typing import Dict
from typing import List
from rdkit import Chem
from rdkit.Chem import Descriptors
from rdkit.Chem import rdFMCS
from schrodinger import stepper
from schrodinger.application.pathfinder import filtering
from schrodinger.models import parameters
from schrodinger.structutils import filter
from schrodinger.structutils.smiles import STEREO_FROM_ANNOTATION_AND_GEOM
from schrodinger.structutils.smiles import SmilesGenerator
from . import bigquery_deduplication
from . import env_keys
from . import redshift_deduplication
from . import scorers
from . import utils
from .basesteps import LoggerMixin
from .basesteps import MaeMapStep
from .basesteps import MolMapStep
from .basesteps import MolMolWorkflow
from .basesteps import MolReduceStep
from .dataclasses import ScoredMol
from .dataclasses import ScoredMolSerializer
from .dataclasses import ScoredSmiles
from .dataclasses import ScoredSmilesIOMixin
# Positive infinity; used in this module as the open-ended bound of numeric
# ranges and thresholds.
INF = float('inf')
# FIXME: FILTER_TABLE is to be removed in AD-480 and replaced with a single
# param. Until then the table type is the Redshift table when the AWS service
# is available and the BigQuery table otherwise.
FILTER_TABLE = (redshift_deduplication.RSTable
if env_keys.is_aws_service_available() else
bigquery_deduplication.BQTable)
class InMemoryUniqueRandomSampleFilter(stepper.Chain):
    """
    A chain that first drops molecules whose SMILES was already seen and then
    draws a random sample from the remaining ones, all in memory.
    """

    class Settings(parameters.CompoundParam):
        # sample size, forwarded to the random sampling step
        n: int = None
        # random seed, forwarded to the random sampling step
        seed: int = None

    def buildChain(self):
        """
        Add the deduplication step followed by the random sampling step.
        """
        dedupe_step = InMemoryUniqueSmilesFilter()
        self.addStep(dedupe_step)
        sampling_step = InMemoryRandomSampleFilter(n=self.settings.n,
                                                   seed=self.settings.seed)
        self.addStep(sampling_step)
class InMemoryUniqueSmilesFilter(MolReduceStep):
    """
    Pass each molecule through only the first time its SMILES (as generated by
    `Chem.MolToSmiles`) is encountered; later duplicates are dropped.
    """

    def reduceFunction(self, mols):
        """
        Yield the molecules from `mols` that have a not yet seen SMILES.

        :param mols: the molecules to deduplicate
        :type mols: iterable[Chem.Mol]
        """
        emitted = set()
        for mol in mols:
            key = Chem.MolToSmiles(mol)
            if key in emitted:
                continue
            emitted.add(key)
            yield mol
class InMemoryRandomSampleFilter(MolReduceStep):
    """
    A filter that outputs a random subsample of its input molecules. The
    sample size is given by the `n` setting and the random number generator
    is seeded with the `seed` setting.

    This is reservoir sampling (Algorithm R), which does not need to know the
    length of the input up front.
    See https://en.wikipedia.org/wiki/Reservoir_sampling
    """

    class Settings(parameters.CompoundParam):
        # the number of molecules to keep
        n: int = None
        # the seed handed to `random.Random`
        seed: int = None

    def validateSettings(self):
        """
        Require `n` to be set and to be positive.
        """
        sample_size = self.settings.n
        if sample_size is None:
            message = 'n must be set'
        elif sample_size <= 0:
            message = 'n must be larger than 0'
        else:
            return []
        return [stepper.SettingsError(self, message)]

    def reduceFunction(self, mols):
        """
        Yield a random sample of the molecules in `mols`.

        The sample is kept as SMILES strings, so the yielded molecules are
        rebuilt from SMILES rather than being the input objects.
        """
        sample_size = self.settings.n
        rng = Random(self.settings.seed)
        # a single iterator so the second loop resumes where the first stops
        remaining = iter(mols)

        # fill the reservoir with the first `sample_size` molecules
        reservoir = []
        for mol in remaining:
            reservoir.append(Chem.MolToSmiles(mol))
            if len(reservoir) == sample_size:
                break

        # every later molecule replaces a random slot with decreasing odds;
        # `upper` is the zero-based position of the molecule in the input
        upper = sample_size
        for mol in remaining:
            slot = rng.randint(0, upper)
            upper += 1
            if slot < sample_size:
                reservoir[slot] = Chem.MolToSmiles(mol)

        yield from map(Chem.MolFromSmiles, reservoir)
class CloudOptionMixin:
    """
    A mixin for steps that can run either on cloud services or in memory,
    selected through the `use_cloud_services` setting.
    """

    class Settings(parameters.CompoundParam):
        # whether the cloud-backed implementation should be used
        use_cloud_services: bool = False

    def validateSettings(self):
        """
        Extend the inherited validation with an error for the case where
        cloud services are requested but `env_keys.has_cloud_services()`
        reports none.
        """
        issues = super().validateSettings()
        cloud_requested = self.settings.use_cloud_services
        if not cloud_requested or env_keys.has_cloud_services():
            return issues
        missing_key_error = stepper.SettingsError(
            self, 'use_cloud_services is set to True but no cloud '
            'environment variable key is defined!')
        return issues + [missing_key_error]
class UniqueRandomSampleFilter(CloudOptionMixin, stepper.Chain):
    """
    A chain that deduplicates and randomly samples molecules, using either a
    cloud service (BigQuery or Redshift) or the in-memory implementation.
    """

    class Settings(CloudOptionMixin.Settings):
        # the cloud table; only used by the cloud-backed steps
        table: FILTER_TABLE
        # the sample size
        n: int = None
        # the random seed; only passed to the in-memory step
        seed: int = None

    def buildChain(self):
        """
        Add the single step that matches the settings and environment.

        :raise RuntimeError: if cloud services are requested but neither the
            GCP key nor the AWS service is available
        """
        opts = self.settings
        if not opts.use_cloud_services:
            self.addStep(
                InMemoryUniqueRandomSampleFilter(n=opts.n, seed=opts.seed))
            return

        if env_keys.SCHRODINGER_GCP_KEY:
            step = bigquery_deduplication.BQDedupeAndRandomSampleFilter(
                table=copy.deepcopy(opts.table), n=opts.n)
        elif env_keys.is_aws_service_available():
            step = redshift_deduplication.RSDeduplicateAndRandomSampleFilter(
                table=copy.deepcopy(opts.table), n=opts.n)
        else:
            raise RuntimeError('Missing cloud service environment keys')
        self.addStep(step)
class UniqueSmilesFilter(CloudOptionMixin, stepper.Chain):
    """
    A chain that filters on unique SMILES, using either a cloud service
    (BigQuery or Redshift) or the in-memory filter.
    """

    class Settings(CloudOptionMixin.Settings):
        # the cloud table; only used by the cloud-backed steps
        table: FILTER_TABLE

    def buildChain(self):
        """
        Add the single step that matches the settings and environment.

        :raise RuntimeError: if cloud services are requested but neither the
            GCP key nor the AWS service is available
        """
        opts = self.settings
        if not opts.use_cloud_services:
            self.addStep(InMemoryUniqueSmilesFilter())
            return

        if env_keys.SCHRODINGER_GCP_KEY:
            # the BigQuery step takes the table's fields as keyword arguments
            step = bigquery_deduplication.BQUniqueSmilesFilter(
                **opts.table.toDict())
        elif env_keys.is_aws_service_available():
            step = redshift_deduplication.RSUniqueSmilesFilter(
                table=copy.deepcopy(opts.table))
        else:
            raise RuntimeError('Missing cloud service environment keys')
        self.addStep(step)
class RandomSampleFilter(CloudOptionMixin, stepper.Chain):
    """
    A chain that randomly samples molecules, using either a cloud service
    (BigQuery or Redshift) or the in-memory sampler.
    """

    class Settings(CloudOptionMixin.Settings):
        # the cloud table; only used by the cloud-backed steps
        table: FILTER_TABLE
        # the sample size
        n: int = None
        # the random seed; only passed to the in-memory step
        seed: int = None

    def buildChain(self):
        """
        Add the single step that matches the settings and environment.

        :raise RuntimeError: if cloud services are requested but neither the
            GCP key nor the AWS service is available
        """
        opts = self.settings
        if not opts.use_cloud_services:
            self.addStep(InMemoryRandomSampleFilter(n=opts.n, seed=opts.seed))
            return

        if env_keys.SCHRODINGER_GCP_KEY:
            step = bigquery_deduplication.BQRandomSampleFilter(
                table=copy.deepcopy(opts.table), n=opts.n)
        elif env_keys.is_aws_service_available():
            step = redshift_deduplication.RSRandomSampleFilter(
                table=copy.deepcopy(opts.table), n=opts.n)
        else:
            raise RuntimeError('Missing cloud service environment keys')
        self.addStep(step)
class MaeUniqueSmilesFilter(MaeMapStep):
    """
    Pass each structure through only the first time its SMILES is
    encountered by this step.
    """

    def setUp(self):
        """
        Create the SMILES generator and the empty set of seen SMILES.
        """
        super().setUp()
        self._seen_smiles = set()
        self._smiles_generator = SmilesGenerator(
            STEREO_FROM_ANNOTATION_AND_GEOM, unique=True)

    def mapFunction(self, st):
        """
        Yield `st` unless a structure with the same SMILES was seen before.
        """
        key = self._smiles_generator.getSmiles(st)
        if key in self._seen_smiles:
            return
        self._seen_smiles.add(key)
        yield st
class SmartsFilter(MolMapStep):
    """
    Pass only the molecules that contain the substructure given by the
    `core_smarts` setting.
    """

    class Settings(parameters.CompoundParam):
        # the SMARTS pattern a molecule has to match
        core_smarts: str = None

    def validateSettings(self):
        """
        Delegate the validation of `core_smarts` to the shared helper.
        """
        return utils.validate_core_smarts(self, self.settings.core_smarts)

    def setUp(self):
        """
        Parse the SMARTS once and reset the cached hydrogen decision.
        """
        super().setUp()
        self._core_smarts = Chem.MolFromSmarts(self.settings.core_smarts)
        self._need_Hs = None

    def mapFunction(self, mol):
        """
        Yield `mol` if it matches the core SMARTS.
        """
        if self._need_Hs is None:
            # the decision is retried on every molecule until one is made
            self._need_Hs = utils.need_Hs(mol, self._core_smarts)
            if self._need_Hs is None:  # wasn't able to match
                return
        candidate = mol
        if self._need_Hs:
            candidate = Chem.AddHs(mol)
        if candidate.HasSubstructMatch(self._core_smarts):
            yield mol
class ChiralCenterCountFilter(MolMapStep):
    """
    Pass only the molecules whose number of chiral centers (unassigned ones
    included) lies in the range given by the `min_value` and `max_value`
    settings, with both ends of the range included.
    """

    class Settings(parameters.CompoundParam):
        min_value: int = 0
        max_value: int = maxsize

    def validateSettings(self):
        """
        Report an error if the range is inverted.
        """
        if self.settings.min_value > self.settings.max_value:
            return [stepper.SettingsError(self, 'min_value > max_value')]
        return []

    def mapFunction(self, mol):
        """
        Yield `mol` if its chiral center count is within the range.
        """
        centers = Chem.FindMolChiralCenters(mol, includeUnassigned=True)
        n_centers = len(centers)
        lowest = self.settings.min_value
        if n_centers < lowest:
            return
        if n_centers > self.settings.max_value:
            return
        yield mol
class ProductFilterMixin:
    """
    A mixin that only passes the molecules accepted by the
    `filter.SmartsFilter` that is created from the file in the `cflt_file`
    setting.

    Since the `filter.SmartsFilter` can `checkStructure` with either `Mol` or
    `Structure` objects the behavior is completely determined by the class
    that it is a mixin for, i.e., either a `MolMapStep` or a `MaeMapStep`.
    """

    class Settings(parameters.CompoundParam):
        """
        :ivar cflt_file: the product filter file from which the SMARTS filter
            is created
        """
        cflt_file: stepper.StepperFile

    def _setSmartsFilter(self):
        """
        (Re)create `self._smarts_filter` from the `cflt_file` setting; it is
        left as None if the setting is empty or is not an existing file.
        """
        self._smarts_filter = None
        path = self.settings.cflt_file
        if not path or not os.path.isfile(path):
            return
        self._smarts_filter = filter.SmartsFilter(filename=path)

    def validateSettings(self):
        """
        Require `cflt_file` to be valid and to define at least one filter.
        """
        issues = utils.validate_file(self, 'cflt_file', required=True)
        if not issues:
            self._setSmartsFilter()
            smarts_filter = self._smarts_filter
            if smarts_filter is None or len(smarts_filter.filters) < 1:
                issues.append(
                    stepper.SettingsError(self, 'no product filters defined'))
        return issues

    def setUp(self):
        """
        Create the SMARTS filter and log how many filters it holds.
        """
        super().setUp()
        self._setSmartsFilter()
        self.logger.info(
            f'{self.getStepId()} {len(self._smarts_filter.filters)} filters')

    def mapFunction(self, molecule):
        """
        Yield `molecule` if the SMARTS filter accepts it.
        """
        accepted = self._smarts_filter.checkStructure(molecule)
        if accepted:
            yield molecule
class ProductFilter(ProductFilterMixin, MolMapStep):
    """
    The product filter for `MolMapStep` inputs; all behavior comes from
    `ProductFilterMixin`.
    """
class MaeProductFilter(ProductFilterMixin, MaeMapStep):
    """
    The product filter for `MaeMapStep` inputs; all behavior comes from
    `ProductFilterMixin`.
    """
class ScoredSmilesProductFilter(ScoredSmilesIOMixin, ProductFilterMixin,
                                LoggerMixin, stepper.MapStep):
    """
    A product filter that uses scored SMILES as inputs and outputs.

    See `ProductFilterMixin`
    """

    def mapFunction(self, scored_smiles):
        """
        Yield `scored_smiles` if the molecule built from its SMILES passes
        the inherited filter; a SMILES that cannot be converted to a molecule
        is logged as an error and dropped.
        """
        mol = Chem.MolFromSmiles(scored_smiles.smiles)
        if not mol:
            self.logger.error(f'ERROR: {self.getStepId()} invalid SMILES:'
                              f' {scored_smiles.smiles}')
            return
        # the inherited generator is exhausted to see if it yields anything
        survivors = list(super().mapFunction(mol))
        if survivors:
            yield scored_smiles
class PropertyFilter(MolMapStep):
    """
    Pass only the molecules whose properties satisfy all the property
    filters that are read from the file in the `filter_file` setting.

    If the filter filename has the .json extension, it will be interpreted as
    a JSON filter file of the kind used by a few panels, including
    PathFinder; that filter format is implemented by
    `schrodinger.ui.qt.filter_dialog_dir.filter_core`.

    A file without the .json extension is considered to be regular text in
    the filter file format that pathfinder and canvas filter use, e.g.,
    lines like::

        r_rdkit_TPSA > 30.0 < 150.0
        i_rdkit_NumRotatableBonds < 10
        i_rdkit_NumChiralCenters == 1
        r_rdkit_MolWt > 150.0 < 575.0
    """

    class Settings(parameters.CompoundParam):
        # the file that defines the property filters
        filter_file: stepper.StepperFile

    def _setPropertyFilter(self):
        """
        (Re)create `self._prop_filter` from the `filter_file` setting; it is
        left as None if the setting is empty or is not an existing file.
        """
        self._prop_filter = None
        path = self.settings.filter_file
        if not path or not os.path.isfile(path):
            return
        self._prop_filter = filtering.get_filter(path, smarts_filter=False)

    def _getCheckCount(self):
        """
        Return the number of active checks: the checked criteria for a JSON
        filter, the number of property names for any other filter.
        """
        prop_filter = self._prop_filter
        if not isinstance(prop_filter, filtering.JSONFilterAdapter):
            return len(prop_filter.getPropertyNames())
        checked = [c for c in prop_filter.filter_obj.criteria if c.checked]
        return len(checked)

    def validateSettings(self):
        """
        Require `filter_file` to be valid and to define at least one check.
        """
        issues = utils.validate_file(self, 'filter_file', required=True)
        if not issues:
            self._setPropertyFilter()
            if self._prop_filter is None or self._getCheckCount() < 1:
                issues.append(
                    stepper.SettingsError(self, 'no property filters defined'))
        return issues

    def setUp(self):
        """
        Create the property filter, cache the names of the descriptors it
        needs and log the number of checks.
        """
        super().setUp()
        self._setPropertyFilter()
        self._descriptor_names = self._prop_filter.getPropertyNames()
        self.logger.info(f'{self.getStepId()}: {self._getCheckCount()} filters')

    def mapFunction(self, mol):
        """
        Add the needed descriptors to `mol` and yield it if the property
        filter lets it through.
        """
        filtering.add_descriptors(mol, self._descriptor_names, refs=None)
        for accepted in self._prop_filter.filter([mol]):
            yield accepted
class ProfileSettings(parameters.CompoundParam):
    """
    The `property_ranges` should be a dict where the key is the property name
    (a string that exists in `filtering.DESCRIPTORS_DICT`). The value should
    be the range of values allowed for that property, given as
    `[minimum, maximum]`, where the ends are included in what is considered
    to be acceptable.
    """
    property_ranges: Dict[str, List[float]]

    def validate(self, step):
        """
        Validate the settings on behalf of a step.

        A range with fewer than two values is reported as an invalid range
        rather than failing with an `IndexError`.

        :param step: the step the issues are reported for
        :type step: stepper._BaseStep

        :return: the issues found; a single warning if there are no ranges
        :rtype: list[stepper.SettingsError or stepper.SettingsWarning]
        """
        if not self.property_ranges:
            return [stepper.SettingsWarning(step, 'no filtering will be done')]
        issues = []
        for prop, bounds in self.property_ranges.items():
            if prop not in filtering.DESCRIPTORS_DICT:
                issues.append(
                    stepper.SettingsError(step,
                                          f'{prop} is not a valid property'))
            # check the length first so the indexing below cannot fail
            if len(bounds) < 2 or bounds[0] > bounds[1]:
                issues.append(
                    stepper.SettingsError(step, f'{prop} has an invalid range'))
        return issues
class ProfileFilter(MolMapStep):
    """
    A product filter where the property profile is defined by the settings.

    The `core_complexity` is a dictionary with the property name as key and
    the value above which the complexity is incremented. If the total
    complexity exceeds `max_complexity` the molecule is rejected.

    Example in yaml notation::

        ProfileFilter:
            property_ranges:
                MolWt: [250, 500]
                FractionCSP3: [0, 1]
                AlogP: [-1, 4]
                NumRotatableBonds: &NumRotatableBonds_range [0, 10]

    :see: ProfileSettings
    """

    class Settings(ProfileSettings):
        # property name -> value above which the complexity is incremented
        core_complexity: Dict[str, int]
        # the highest total complexity that is still accepted
        max_complexity: int = 1
        # rings with more atoms than this count as large rings
        large_ring_cutoff: int = 6

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # NOTE: this registers the descriptor in the module-level
            # `filtering.DESCRIPTORS_DICT`, bound to this settings instance,
            # so the most recently created settings determine the cutoff.
            filtering.DESCRIPTORS_DICT[
                'LargeRingCount'] = lambda x: self.largeRingCount(x)

        def largeRingCount(self, mol):
            """
            Return the number of rings in `mol` that have more atoms than
            `large_ring_cutoff`.

            :type mol: Chem.Mol
            :rtype: int
            """
            return sum(1 for ring in mol.GetRingInfo().AtomRings()
                       if len(ring) > self.large_ring_cutoff)

        def validate(self, step):
            """
            Validate the settings on behalf of a step; extends the validation
            of `ProfileSettings` with the complexity related settings.

            :param step: the step the issues are reported for
            :rtype: list[stepper.SettingsError or stepper.SettingsWarning]
            """
            issues = super().validate(step)
            for prop in self.core_complexity:
                if prop not in filtering.DESCRIPTORS_DICT:
                    issues.append(
                        stepper.SettingsError(
                            step,
                            f'{prop} is not a valid core_complexity property'))
            if any(v < 0 for v in self.core_complexity.values()):
                issues.append(
                    stepper.SettingsError(
                        step, 'core_complexity values should be >= 0'))
            if self.large_ring_cutoff <= 2:
                issues.append(
                    stepper.SettingsWarning(
                        step,
                        f'large_ring_cutoff of {self.large_ring_cutoff} is small'
                    ))
            if self.max_complexity < 0:
                issues.append(
                    stepper.SettingsError(step,
                                          'max_complexity should be >= 0'))
            return issues

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._remainder = []

    def validateSettings(self):
        """
        Delegate the validation to the settings.
        """
        return self.settings.validate(self)

    def setUp(self):
        """
        Run the inherited set up and determine which complexity properties
        are not already covered by `property_ranges`.
        """
        super().setUp()
        # the complexity properties that are not in the property_range dict
        self._remainder = [
            prop for prop in self.settings.core_complexity
            if prop not in self.settings.property_ranges
        ]

    def _getComplexityValue(self, prop, prop_value):
        """
        Return 1 if `prop_value` exceeds the `core_complexity` threshold of
        `prop`, otherwise 0. Properties without a threshold never count.
        """
        if prop_value > self.settings.core_complexity.get(prop, INF):
            return 1
        return 0

    def mapFunction(self, mol):
        """
        Yield `mol` if all its properties are within their ranges and its
        total complexity does not exceed `max_complexity`.
        """
        max_complexity = self.settings.max_complexity
        complexity = 0
        # check the required property ranges keeping track of complexity
        for prop, prop_range in self.settings.property_ranges.items():
            prop_value = filtering.DESCRIPTORS_DICT[prop](mol)
            if prop_value < prop_range[0] or prop_value > prop_range[1]:
                return
            complexity += self._getComplexityValue(prop, prop_value)
            # the complexity never decreases, so reject as soon as possible
            if complexity > max_complexity:
                return
        # get the core complexity ones we may not have computed yet
        for prop in self._remainder:
            prop_value = filtering.DESCRIPTORS_DICT[prop](mol)
            complexity += self._getComplexityValue(prop, prop_value)
            if complexity > max_complexity:
                return
        yield mol
class FepFilter(MolMapStep):
    """
    Only allow molecules through that are amenable to FEP calculations.

    A molecule is considered to be amenable if it is an acceptable
    perturbation with `settings.min_edges` of the reference molecules.
    A perturbation is considered acceptable if the number of heavy atoms of
    the molecule that are not part of the maximum common substructure (MCS)
    with the reference is less than or equal to `settings.max_hac_diff`.

    The settings contain the following parameters:

    - fep_references_file: the title-less SMILES file with SMILES in the 0th
      column (the way that Maestro exports them...)
    - ref_mols: the list of references (Chem.Mol objects)
    - min_edges: the minimum number of edges needed
    - max_hac_diff: the maximum number of heavy atoms not part of the MCS

    If the `fep_references_file` is defined and exists, the `ref_mols` will
    be ignored.
    """

    class Settings(parameters.CompoundParam):
        fep_references_file: stepper.StepperFile
        ref_mols: List[Chem.Mol]
        min_edges: int = 1
        max_hac_diff: int = 10

    def _getRefMols(self):
        """
        Return the reference molecules: the readable entries of the
        references file if that file exists, otherwise `settings.ref_mols`.
        """
        path = self.settings.fep_references_file
        if path and os.path.isfile(path):
            supplier = Chem.SmilesMolSupplier(path, titleLine=0)
            return [ref for ref in supplier if ref]
        return self.settings.ref_mols

    def validateSettings(self):
        """
        Require at least `min_edges` reference molecules.
        """
        issues = utils.validate_file(
            self, 'fep_references_file')  # optional StepperFile
        ref_mols = self._getRefMols()
        if not ref_mols:
            message = 'no reference molecules defined'
        elif len(ref_mols) < self.settings.min_edges:
            message = 'not enough references provided'
        else:
            return issues
        issues.append(stepper.SettingsError(self, message))
        return issues

    def setUp(self):
        """
        Load the reference molecules once for all the molecules to come.
        """
        super().setUp()
        self._ref_mols = self._getRefMols()

    def mapFunction(self, mol):
        """
        Yield `mol` as soon as it has `min_edges` acceptable perturbations.
        """
        n_heavy = mol.GetNumHeavyAtoms()
        n_edges = 0
        for ref in self._ref_mols:
            result = rdFMCS.FindMCS([ref, mol],
                                    ringMatchesRingOnly=True,
                                    completeRingsOnly=False,
                                    timeout=1)
            common_core = Chem.MolFromSmarts(result.smartsString)
            n_outside_core = n_heavy - common_core.GetNumHeavyAtoms()
            if n_outside_core > self.settings.max_hac_diff:
                continue
            n_edges += 1
            if n_edges >= self.settings.min_edges:
                yield mol
                return
class RangeFilter(MolMapStep):
    """
    A filter that takes `ScoredMol` objects and only passes the Mol part of
    those whose score is in the range given by the `min_value` and
    `max_value` settings, with both ends of the range included.
    """
    Input = ScoredMol
    InputSerializer = ScoredMolSerializer

    class Settings(parameters.CompoundParam):
        min_value: float = -INF
        max_value: float = INF

    def validateSettings(self):
        """
        Report an error for an inverted range and a warning for the fully
        open range, which filters nothing.
        """
        lowest = self.settings.min_value
        highest = self.settings.max_value
        issues = []
        if lowest > highest:
            issues.append(
                stepper.SettingsError(self,
                                      'min_value is larger than max_value'))
        if lowest == -INF and highest == INF:
            issues.append(
                stepper.SettingsWarning(self, 'no filtering will be done'))
        return issues

    def mapFunction(self, scored_mol):
        """
        Yield the molecule of `scored_mol` if its score is within the range.
        """
        lowest = self.settings.min_value
        score = scored_mol.score
        in_range = lowest <= score and score <= self.settings.max_value
        if in_range:
            yield scored_mol.mol
class ScoreFilter(MolMolWorkflow):
    """
    A base class for a step consisting of a `self.SCORER_CLASS` instance
    followed by a `RangeFilter`.

    :cvar SCORER_CLASS: the class that will compute the value to be filtered
    :vartype SCORER_CLASS: ScorerStep
    """
    SCORER_CLASS = NotImplementedError

    class Settings(parameters.CompoundParam):
        # the configuration that is applied to both the scorer and the filter
        score_filter: dict

    def __init__(self, scorer_class=None, **kwargs):
        """
        :param scorer_class: if given, replaces `SCORER_CLASS` for this
            instance
        """
        # set before the base initialization so it is already in place there
        if scorer_class:
            self.SCORER_CLASS = scorer_class
        super().__init__(**kwargs)

    def buildChain(self):
        """
        Add the scorer and then the range filter, each configured from the
        `score_filter` setting.
        """
        for step_class in (self.SCORER_CLASS, RangeFilter):
            step = step_class()
            utils.apply_config_settings_to_step(self.settings.score_filter,
                                                step)
            self.addStep(step)
class ScoreFilterChain(MolMolWorkflow):
    """
    A base class for chains that consist of one `ScoreFilter` per entry of
    the `score_filters` setting.
    """

    class Settings(parameters.CompoundParam):
        # one configuration dictionary per `ScoreFilter` to create
        score_filters: List[Dict]

    SCORER_CLASS = NotImplementedError

    def validateSettings(self):
        """
        Warn when there are no score filters, then add the inherited issues.
        """
        own_issues = []
        if len(self.settings.score_filters) == 0:
            own_issues = [
                stepper.SettingsWarning(self, 'no filtering will be done')
            ]
        inherited_issues = super().validateSettings()
        return own_issues + inherited_issues

    def buildChain(self):
        """
        Add a `ScoreFilter` using `SCORER_CLASS` for every configuration.
        """
        for config in self.settings.score_filters:
            self.addStep(ScoreFilter(self.SCORER_CLASS, score_filter=config))
class LigandMLScoreFilterChain(ScoreFilterChain):
    """
    A variable length chain of score filters that use `LigandMLScorer` as
    the scorer; the chain is allowed to be empty.

    Each entry of `score_filters` combines the settings of a
    `LigandMLScorer` with those of a `RangeFilter` for one ligand ML model.

    Example yaml file configuration::

        LigandMLScoreFilterChain:
            score_filters:
                - ml_file: model1.qzip
                  min_value: -3.0
                  max_value: 2.0
                - ml_file: model2.qzip
                  max_value: 200
    """
    SCORER_CLASS = scorers.LigandMLScorer
class MaxMolWtSettings(parameters.CompoundParam):
    """
    Settings that hold an optional upper limit for the molecular weight.
    """
    # the maximum molecular weight; None means there is no limit
    max_mol_wt: float = None
class MaxMolWtMixin:
    """
    A mixin for steps that filter Chem.Mol objects by molecular weight.
    Expects the class to have a max_mol_wt setting, i.e., its settings
    should be an instance of `MaxMolWtSettings`.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # guards against a class design mistake, not against user input
        msg = 'The Settings should inherit from MaxMolWtSettings'
        assert isinstance(self.settings, MaxMolWtSettings), msg

    def validateMaxMolWtSettings(self):
        """
        :return: A settings error if the max mol. weight setting is set to a
            value that is not greater than 0.0, None otherwise.
        :rtype: stepper.SettingsError or None
        """
        max_mol_wt = self.settings.max_mol_wt
        if max_mol_wt is not None and max_mol_wt <= 0.0:
            return stepper.SettingsError(
                self, 'max_mol_wt should be greater than 0.0')
        return None

    def hasAcceptableMolWt(self, mol):
        """
        Check if the specified mol passes molecular weight checks.

        :param mol: Mol to be checked.
        :type mol: Chem.Mol

        :return: True if there is no maximum or if the molecular weight of
            the mol does not exceed it, False otherwise.
        :rtype: bool
        """
        max_mw = self.settings.max_mol_wt
        if max_mw is None:
            # no limit set, so the weight does not need to be computed
            return True
        return Descriptors.MolWt(mol) <= max_mw
class ScoredUniqueSmilesFilter(ScoredSmilesIOMixin, stepper.ReduceStep):
    """
    A filter that yields one `ScoredSmiles` object per unique SMILES. If a
    SMILES is encountered more than once, `settings.keep_lowest` determines
    whether the lowest or the highest of its scores is retained.
    """

    class Settings(parameters.CompoundParam):
        # keep the lowest score of duplicates if True, the highest otherwise
        keep_lowest: bool = True

    def reduceFunction(self, inputs):
        """
        Yield the unique scored SMILES in order of first appearance.
        """
        pick_best = min if self.settings.keep_lowest else max
        best_scores = {}
        for item in inputs:
            smiles, score = item.smiles, item.score
            previous = best_scores.get(smiles)
            if previous is None:
                best_scores[smiles] = score
            else:
                # the previous score goes first so it wins in case of a tie
                best_scores[smiles] = pick_best(previous, score)
        for smiles, score in best_scores.items():
            yield ScoredSmiles(smiles=smiles, score=score)
class SortedScoreFilter(ScoredSmilesIOMixin, stepper.ReduceStep):
    """
    A filter that sorts the scored SMILES by score and yields the first
    `max_out` of them. The `reverse` setting selects a descending instead of
    an ascending sort order.
    """

    class Settings(parameters.CompoundParam):
        # sort from high to low score if True
        reverse: bool = False
        # the number of items to output; None outputs all of them
        max_out: int = None

    def reduceFunction(self, inputs):
        """
        Yield the best `max_out` inputs in sorted order.
        """
        descending = self.settings.reverse
        ranked = list(inputs)
        ranked.sort(key=lambda item: item.score, reverse=descending)
        for item in ranked[:self.settings.max_out]:
            yield item