Source code for schrodinger.application.bioluminate.antibody.io.csv_reader

"""
Utilities to read and validate data from a CSV file containing antibody sequence
information.
"""

from contextlib import contextmanager
from dataclasses import dataclass
import csv
import enum

from typing import Callable
from typing import ClassVar
from typing import Dict
from typing import Generator
from typing import List
from typing import Optional
from typing import Tuple

from schrodinger.utils import csv_unicode
from schrodinger.utils import fileutils

try:
    from schrodinger.application.prime.packages import antibody
except ImportError:
    antibody = None


class Headers(enum.Enum):
    """
    Column headers for all modeling modes.

    `EXTRA_COLS` is not a header that appears in a CSV file. It is used as the
    key under which a reader collects any cells beyond the expected columns.
    """
    NAME = 'NAME'
    LC = 'LC'
    HC = 'HC'
    HC1 = 'HC1'
    LC1 = 'LC1'
    HC2 = 'HC2'
    LC2 = 'LC2'
    # Explicit value instead of `enum.auto()`: auto() after string-valued
    # members depends on a fallback that Python 3.12 deprecates and later
    # versions remove. 1 is the value that fallback used to produce, so the
    # member value is unchanged on older interpreters.
    EXTRA_COLS = 1
# Headers that never hold a chain sequence.
NON_SEQ_HEADERS = (Headers.NAME, Headers.EXTRA_COLS)
# Every remaining header (in enum declaration order) holds a chain sequence.
SEQ_HEADERS = tuple(
    filter(lambda member: member not in NON_SEQ_HEADERS, Headers))
def get_valid_header_groups_string() -> str:
    """
    Build a human-readable listing of every header group accepted by the
    registered validators, e.g. "['NAME', 'HC'], ['NAME', 'HC', 'LC']".
    """
    group_strings = []
    for validator in _VALIDATORS:
        header_names = [header.value for header in validator.valid_headers]
        group_strings.append(str(header_names))
    return ', '.join(group_strings)
class ModelingMode(enum.Enum):
    """
    The broad kinds of antibody modeling that are supported.
    """
    SINGLE_DOMAIN = enum.auto()
    MONOSPECIFIC = enum.auto()
    BISPECIFIC = enum.auto()
# ============================================================================== # Errors # ==============================================================================
class DataViolation(enum.Enum):
    """
    Categories of invalid data that may be found in an antibody CSV file.

    Each member's value is the human-readable description used in error
    messages.
    """
    NAME = 'invalid name'
    HC_SEQ = 'invalid heavy chain sequence'
    LC_SEQ = 'invalid light chain sequence'
    NUM_COLS = 'invalid number of columns'
    HEADERS = 'invalid headers'
class BaseInvalidAntibodyCSVError(Exception):
    """
    Base class for exceptions relating to invalid antibody csv files.

    Subclasses should define the message to display (by overriding
    `__str__`).
    """

    def __init__(self, csv_file: str):
        """
        :param csv_file: Path of the offending antibody CSV file.
        """
        # Forward csv_file so that `self.args == (csv_file,)`. Exceptions are
        # unpickled by calling `cls(*self.args)`; with empty args that call
        # would fail because `csv_file` is a required argument.
        super().__init__(csv_file)
        self._csv_file = csv_file
class InvalidFileTypeError(BaseInvalidAntibodyCSVError):
    """
    Raised when the supplied antibody file is not a CSV file.
    """

    def __str__(self):
        csv_file = self._csv_file
        return f'Expected the supplied antibody file {csv_file} to be a .csv file'
class InvalidCSVFormattingError(BaseInvalidAntibodyCSVError):
    """
    Exception to raise when the antibody csv file is not formatted with valid
    headers.
    """

    def __str__(self):
        """
        Format all valid Headers enum combinations into strings and then return
        a message stating that antibody CSV headers must match one of those
        header combinations.
        """
        valid_headers = get_valid_header_groups_string()
        # The previous message ended with 'exactly ' + ' match.', which
        # produced a doubled space; the literal pieces now join cleanly.
        return (f'Expected the supplied antibody file {self._csv_file} to have '
                'exactly one of the following sequences of headers: '
                f'{valid_headers}. Note that the column order must exactly '
                'match.')
class InvalidCSVLengthError(BaseInvalidAntibodyCSVError):
    """
    Exception for an antibody csv file that has no rows besides the header.
    """

    def __str__(self):
        message = ('Expected the supplied antibody file {} to have at least '
                   'one non-header row.')
        return message.format(self._csv_file)
class InvalidRowError(Exception):
    """
    Exception to raise when a particular row is invalid.

    The offending violations and row number are available as the
    `data_violations` and `row_num` attributes.
    """

    def __init__(self, data_violations: List[DataViolation], row_num: int):
        """
        :param data_violations: Every violation detected in the row; their
            `value` strings are joined into the error message.
        :param row_num: The row number reported in the message.
        """
        self.data_violations = list(data_violations)
        self.row_num = row_num
        intro = (f'Row {row_num} of the supplied file contains invalid '
                 f'information: ')
        violations_str = ', '.join(
            violation.value for violation in self.data_violations)
        super().__init__(intro + violations_str)

    def __reduce__(self):
        # The default exception pickling calls `cls(*self.args)`, but args
        # only holds the formatted message, which does not match this
        # constructor's signature. Rebuild from the original arguments.
        return (type(self), (self.data_violations, self.row_num))
# ==============================================================================
# Private Helpers
# ==============================================================================


def _csv_headers_are_valid(csv_headers: Optional[List[str]],
                           valid_headers: List[Headers]) -> bool:
    """
    Return whether the supplied csv headers match the order and spelling of the
    valid headers. Note that the order is predefined by design in order to
    minimize processing on our end.

    CSV header cells are compared after stripping surrounding whitespace and
    upper-casing them.

    :param csv_headers: The headers of a CSV file. `None` (which is what
        `csv.DictReader.fieldnames` returns for an empty file) is treated as
        having no headers at all.
    :param valid_headers: Header enums to compare with the CSV headers.
    """
    cleaned_headers = [header.strip().upper() for header in csv_headers or ()]
    expected_headers = [header.value for header in valid_headers]
    return cleaned_headers == expected_headers


def _get_name_violation(name: str) -> Optional[DataViolation]:
    """
    Return a name data violation if `name` is not a valid job name.
    """
    if not fileutils.is_valid_jobname(name):
        return DataViolation.NAME
    return None


def _get_num_columns_violation(
        extra_vals: Optional[List[str]]) -> Optional[DataViolation]:
    """
    Return a number of columns data violation if there are any extra columns.

    Note that having too few cells filled out in a particular row is not
    considered a number of columns violation since those cells get filled with
    a predetermined value that will be validated elsewhere.
    """
    if extra_vals:
        return DataViolation.NUM_COLS
    return None


def _get_heavy_chain_violation(heavy_seq: str) -> Optional[DataViolation]:
    """
    Return a heavy chain data violation if `heavy_seq` is rejected by
    `antibody.heavy_chain_seq_is_valid`.
    """
    # NOTE: `antibody` is None when the prime package failed to import (see
    # the module-level try/except), in which case this raises AttributeError.
    if not antibody.heavy_chain_seq_is_valid(heavy_seq):
        return DataViolation.HC_SEQ
    return None


def _get_light_chain_violation(light_seq: str) -> Optional[DataViolation]:
    """
    Return a light chain data violation if `light_seq` is rejected by
    `antibody.light_chain_seq_is_valid`.
    """
    # NOTE: `antibody` is None when the prime package failed to import (see
    # the module-level try/except), in which case this raises AttributeError.
    if not antibody.light_chain_seq_is_valid(light_seq):
        return DataViolation.LC_SEQ
    return None


# ==============================================================================
# Validators
# ==============================================================================


@dataclass(frozen=True)
class _AbstractAntibodyCSVValidator:
    """
    A class with the ability to validate row data from a CSV file with antibody
    sequences. Instance variables are immutable.

    Subclasses should never be instantiated directly. Instead, use
    `get_reader()` to get a reader instance that is equipped with the correct
    validator.

    Subclasses must define only `modeling_mode` and `validators_by_header`.
    `valid_headers` is derived automatically from the keys of
    `validators_by_header` (in declaration order, excluding
    `Headers.EXTRA_COLS`).
    """
    modeling_mode: ClassVar[ModelingMode] = NotImplemented
    validators_by_header: ClassVar[Dict[Headers, Callable]] = NotImplemented
    valid_headers: ClassVar[Optional[Tuple[Headers, ...]]] = None

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # EXTRA_COLS is a collection key, not a real column, so it is not part
        # of the header row a CSV file must have.
        cls.valid_headers = tuple(
            header for header in cls.validators_by_header
            if header is not Headers.EXTRA_COLS)

    def determine_data_violations(
            self,
            row_data: Dict[Headers, str],
            validate_seqs: bool = False) -> List[DataViolation]:
        """
        Return all data violations associated with the given row data.

        :param row_data: The data contained in a single CSV row.
        :param validate_seqs: Whether to also check the light and heavy chain
            seqs.
        """
        violations = []
        for header, validator in self.validators_by_header.items():
            if not validate_seqs and header in SEQ_HEADERS:
                continue
            # `.get` because EXTRA_COLS is only present when the row has
            # extra cells.
            data = row_data.get(header)
            if violation := validator(data):
                violations.append(violation)
        return violations

    def validate_row(self,
                     row_data: Dict[Headers, str],
                     row_num: int,
                     validate_seqs: bool = False):
        """
        Raise an error if any information in the supplied row is invalid.

        :param row_data: The data contained in a single CSV row.
        :param row_num: Row number to report in the error message.
        :param validate_seqs: Whether to also check the light and heavy chain
            seqs.
        :raises InvalidRowError: If any data violation is found.
        """
        if violations := self.determine_data_violations(row_data,
                                                        validate_seqs):
            raise InvalidRowError(violations, row_num)


class _SingleDomainAntibodyCSVValidator(_AbstractAntibodyCSVValidator):
    modeling_mode = ModelingMode.SINGLE_DOMAIN
    validators_by_header = {
        Headers.EXTRA_COLS: _get_num_columns_violation,
        Headers.NAME: _get_name_violation,
        Headers.HC: _get_heavy_chain_violation,
    }


class _MonospecificAntibodyCSVValidator(_AbstractAntibodyCSVValidator):
    modeling_mode = ModelingMode.MONOSPECIFIC
    validators_by_header = {
        Headers.EXTRA_COLS: _get_num_columns_violation,
        Headers.NAME: _get_name_violation,
        Headers.HC: _get_heavy_chain_violation,
        Headers.LC: _get_light_chain_violation
    }


class _BispecificAntibodyCSVValidator(_AbstractAntibodyCSVValidator):
    modeling_mode = ModelingMode.BISPECIFIC
    validators_by_header = {
        Headers.EXTRA_COLS: _get_num_columns_violation,
        Headers.NAME: _get_name_violation,
        Headers.HC1: _get_heavy_chain_violation,
        Headers.LC1: _get_light_chain_violation,
        Headers.HC2: _get_heavy_chain_violation,
        Headers.LC2: _get_light_chain_violation
    }


# TODO: add bispecifics back in BIOLUM-4664
_VALIDATORS = (_SingleDomainAntibodyCSVValidator(),
               _MonospecificAntibodyCSVValidator())


class _AntibodyCSVReader(csv.DictReader):
    """
    A CSV DictReader that uses an antibody CSV validator to potentially
    validate the data in each row of the file. The validator should be supplied
    by `get_reader()`.

    Rows are keyed by `Headers` members. Missing cells are filled with '' and
    any surplus cells are collected in a list under `Headers.EXTRA_COLS`.
    """

    def __init__(self, validator: _AbstractAntibodyCSVValidator, *args,
                 **kwargs):
        self._validator: _AbstractAntibodyCSVValidator = validator
        super().__init__(*args,
                         fieldnames=self._validator.valid_headers,
                         restkey=Headers.EXTRA_COLS,
                         restval='',
                         **kwargs)

    def validate_row(self,
                     row_data: Dict[Headers, str],
                     row_num: int,
                     validate_seqs: bool = False):
        """
        Raise `InvalidRowError` if the supplied row is invalid according to
        this reader's validator.
        """
        self._validator.validate_row(row_data, row_num, validate_seqs)

    def determine_data_violations(
            self,
            row_data: Dict[Headers, str],
            validate_seqs: bool = False) -> List[DataViolation]:
        """
        Return all data violations in the supplied row according to this
        reader's validator.

        :param validate_seqs: Whether to also check the light and heavy chain
            seqs. Previously this was never forwarded, so sequences could not
            be checked through this method.
        """
        return self._validator.determine_data_violations(
            row_data, validate_seqs)


# ==============================================================================
# Public API
# ==============================================================================
@contextmanager
def get_reader(
        csv_file: str) -> Generator[_AntibodyCSVReader, None, None]:
    """
    A context manager that returns an antibody CSV reader equipped with the
    appropriate validator class given the supplied CSV file. Raises an error
    if the given file does not meet the standards for proper parsing.

    :param csv_file: A CSV file with antibody sequences.
    :raises InvalidFileTypeError: If `csv_file` is not a CSV file.
    :raises InvalidCSVFormattingError: If the file's header row is missing
        (empty file) or does not exactly match any supported header group.
    """
    if not fileutils.is_csv_file(csv_file):
        raise InvalidFileTypeError(csv_file)

    with csv_unicode.reader_open(csv_file) as ab_csv_handle:
        # Accessing `fieldnames` reads the header row from the shared handle,
        # so the reader yielded below continues from the row after it.
        # `fieldnames` is None for an empty file; previously that caused a
        # TypeError instead of the intended formatting error.
        csv_headers = csv.DictReader(ab_csv_handle).fieldnames or []
        for validator in _VALIDATORS:
            if _csv_headers_are_valid(csv_headers, validator.valid_headers):
                yield _AntibodyCSVReader(validator, f=ab_csv_handle)
                return
        # csv columns do not match any of the parser classes
        raise InvalidCSVFormattingError(csv_file)