Source code for schrodinger.application.steps.redshift_deduplication

"""
Steps for deduplicating and randomly sampling data through AWS (currently
limited to use with Mols).
"""
import gzip
import time
import itertools
import os
import re
import uuid

import more_itertools

from schrodinger.application.steps import basesteps
from schrodinger.application.steps import env_keys
from schrodinger.application.steps import utils
from schrodinger.models import parameters
from schrodinger.stepper.cloud import aws_client
from schrodinger import stepper
from schrodinger.stepper import logger

from .basesteps import MolMolMixin

try:
    import boto3
except ImportError:
    boto3 = None

# ==============================================================================
# Error messages
# ==============================================================================
INSUFFICIENT_DATABASE_SETTINGS = (
    'Cannot create new Redshift table with insufficient database settings:'
    '\ncluster_id={0}\ndatabase={1}\ndatabase_user={2}')
MISSING_BUCKET = ('AWS Redshift requires use of S3 buckets and bucket name '
                  'must be set through environment variables. See '
                  '`schrodinger.application.steps.env_keys` for more details.')
QUERY_ERROR = "Errors while running query {0}\n{1}"
INVALID_TABLE_NAME = ("Invalid table name specified. Table names must only "
                      "contain alphanumeric values and underscores, and must "
                      "also begin with a non-digit value.")

# ==============================================================================
# CONSTANTS
# ==============================================================================

REDSHIFT_CLIENT = None
S3_CLIENT = None
_MAX_FILES = 1000
# Default maximum size of one file when exporting from Redshift in MB
_MAX_FILE_SIZE = 5.0  # MB

# ==============================================================================
# AWS Clients
# ==============================================================================


def _get_redshift_client():
    global REDSHIFT_CLIENT
    if REDSHIFT_CLIENT is None:
        REDSHIFT_CLIENT = aws_client.get_client('redshift-data')
    return REDSHIFT_CLIENT


def _get_s3_client():
    global S3_CLIENT
    if S3_CLIENT is None:
        S3_CLIENT = aws_client.get_client('s3')
    return S3_CLIENT


# ==============================================================================
# REDSHIFT CONNECTION METHODS
# ==============================================================================


def _get_service_credentials():
    """
    Retrieves aws credentials for authentication across services
    (ex: Redshift <-> S3) by checking the following settings in order:
    1. a valid IAM role set through an environment variable.
    2. Access keys set through environment variables.
    3. Credentials set under ~/.aws with the profile given through an
       environment variable.

    :return: credentials formatted according to redshift service query
        statement structure.
    :rtype: str
    """
    if env_keys.REDSHIFT_S3_IAM_ROLE:
        return f"IAM_ROLE '{env_keys.REDSHIFT_S3_IAM_ROLE}'"
    elif env_keys.SCHRODINGER_AWS_KEY:
        return f"""
            ACCESS_KEY_ID '{env_keys.SCHRODINGER_AWS_KEY}'
            SECRET_ACCESS_KEY '{env_keys.SCHRODINGER_AWS_SECRET_KEY}'"""
    else:
        # note that this is typically used when access keys are set under ~/.aws
        aws_credentials = aws_client.get_credentials()
        return f"""
            ACCESS_KEY_ID '{aws_credentials.access_key}'
            SECRET_ACCESS_KEY '{aws_credentials.secret_key}'
            SESSION_TOKEN '{aws_credentials.token}'"""
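

# Example of how the returned clause is embedded in an UNLOAD statement (an
# illustrative sketch only; the bucket, table, and role ARN below are
# placeholders, not real resources):
#
#     UNLOAD ('SELECT data FROM "my_table"')
#     TO 's3://my-bucket/my_table_output/'
#     MAXFILESIZE 5.0 MB
#     IAM_ROLE 'arn:aws:iam::123456789012:role/example-redshift-role';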


def _execute_redshift_statement(sql):
    """
    Execute a redshift sql query with appropriate database connection settings.

    :param sql: query to execute
    :type sql: str

    :return: query ID
    :rtype: str
    """
    response = _get_redshift_client().execute_statement(
        ClusterIdentifier=env_keys.REDSHIFT_CLUSTER_ID,
        Database=env_keys.REDSHIFT_DATABASE,
        DbUser=env_keys.REDSHIFT_DB_USER,
        Sql=sql)
    return response['Id']


def _wait_for_redshift_query(query_id, raise_error=True):
    """
    Wait for a redshift query to complete. Raises a RuntimeError if the query
    fails and `raise_error` is enabled.

    :param query_id: unique ID of query to monitor
    :type query_id: str

    :param raise_error: whether to raise an exception if the query returns an
        error; default behavior is True (to raise an exception).
    :type raise_error: bool

    :return: the query job response
    :rtype: dict
    """
    client = _get_redshift_client()
    # sleep till query finishes
    while _get_query_status(query_id) not in ['FINISHED', 'FAILED']:
        time.sleep(5)
    # check for errors
    response = client.describe_statement(Id=query_id)
    if response.get('Error') and raise_error:
        msg = QUERY_ERROR.format(response['QueryString'], response['Error'])
        raise RuntimeError(msg)
    return response


def _get_query_status(query_id):
    """
    Helper method to retrieve query's status by ID.

    :param query_id: unique ID of query to retrieve status
    :type query_id: str

    :return: 'SUBMITTED'|'PICKED'|'STARTED'|'FINISHED'|'ABORTED'|'FAILED'|'ALL'
    :rtype: str
    """
    client = _get_redshift_client()
    return client.describe_statement(Id=query_id)['Status']



def run_query(sql, raise_error=True):
    """
    Helper method to execute an SQL query and wait for its completion.

    :param sql: query to execute
    :type sql: str

    :param raise_error: whether to raise an exception for a failed query - see
        `_wait_for_redshift_query` docstring for more details.
    :type raise_error: bool

    :return: the query job response
    :rtype: dict
    """
    query_id = _execute_redshift_statement(sql)
    return _wait_for_redshift_query(query_id, raise_error)
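

# Example usage (illustrative sketch; assumes the REDSHIFT_CLUSTER_ID,
# REDSHIFT_DATABASE and REDSHIFT_DB_USER environment variables described in
# `env_keys` point at a reachable cluster, and that "my_table" exists):
#
#     response = run_query('SELECT COUNT(*) FROM "my_table";')
#     assert response['Status'] == 'FINISHED'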


# ==============================================================================
# REDSHIFT DATABASE METHODS
# ==============================================================================


def _create_table(table_id):
    """
    Creates a new table under the default database set in the redshift
    execution settings.

    :param table_id: unique table ID
    :type table_id: str
    """
    logger.debug(f"creating table... {table_id}")
    sql = f'CREATE TABLE "{table_id}" (data varchar(max));'
    _execute_redshift_statement(sql)


def _drop_table(table_id):
    """
    Delete the requested table.

    :param table_id: table to delete by ID.
    :type table_id: str
    """
    logger.debug(f"dropping table... {table_id}")
    sql = f'DROP TABLE "{table_id}";'
    _execute_redshift_statement(sql)


def _table_exists(table_id):
    """
    Check if a table exists in the database.

    :param table_id: table to check status for.
    :type table_id: str

    :return: whether the requested table exists in the database.
    :rtype: bool
    """
    sql = f'SELECT EXISTS (SELECT 1 FROM "{table_id}");'
    response = run_query(sql, raise_error=False)
    if response.get('Error'):
        err_msg = response.get('Error')
        if f'relation "{table_id.lower()}" does not exist' in err_msg:
            return False
        # everything else should raise an error
        raise RuntimeError(QUERY_ERROR.format(sql, err_msg))
    return True


def _deduplicate_table(table_id, max_file_size=_MAX_FILE_SIZE):
    """
    Deduplicates the requested table and exports the results to s3.

    :param table_id: table to deduplicate.
    :type table_id: str

    :param max_file_size: maximum file size for exporting before batching.
    :type max_file_size: float
    """
    query_str = f"""
        SELECT data FROM \"{table_id}\"
        GROUP BY data"""
    destination = _execute_unload_query(table_id, query_str, max_file_size)
    logger.debug(f"deduplicated table {table_id} and exported to: "
                 f"{destination}")


def _random_sample_table(table_id, n, max_file_size=_MAX_FILE_SIZE):
    """
    Samples the requested table randomly and exports the results to s3.

    :param table_id: table to randomly sample.
    :type table_id: str

    :param n: number of samples requested
    :type n: int

    :param max_file_size: maximum file size for exporting before batching.
    :type max_file_size: float
    """
    query_str = f"""
        SELECT data FROM \"{table_id}\"
        WHERE RANDOM() < {n}/(SELECT COUNT(*) FROM \"{table_id}\")::float"""
    destination = _execute_unload_query(table_id, query_str, max_file_size)
    logger.debug(f"randomly sampled {n} entries from table {table_id} and "
                 f"exported to: {destination}")


def _deduplicate_and_random_sample_table(table_id,
                                         n,
                                         max_file_size=_MAX_FILE_SIZE):
    """
    The requested table is first deduplicated, then randomly sampled for `n`
    entries, with the results exported to s3.

    :param table_id: table to deduplicate and randomly sample.
    :type table_id: str

    :param n: number of samples requested
    :type n: int

    :param max_file_size: maximum file size for exporting before batching.
    :type max_file_size: float
    """
    query_str = f"""
        SELECT data FROM \"{table_id}\"
        GROUP BY data
        HAVING RANDOM() < {n}/(SELECT COUNT(DISTINCT data) FROM \"{table_id}\")::float"""
    destination = _execute_unload_query(table_id, query_str, max_file_size)
    logger.debug(f"deduplicated and randomly sampled table {table_id} and "
                 f"exported to: {destination}")


def _export_table(table_id, max_file_size=_MAX_FILE_SIZE):
    """
    Generic export call to unload the table into s3.

    :param table_id: table to export.
    :type table_id: str

    :param max_file_size: maximum file size for exporting before batching.
    :type max_file_size: float
    """
    query_str = f"""
        SELECT * FROM \"{table_id}\""""
    destination = _execute_unload_query(table_id, query_str, max_file_size)
    logger.debug(f"exported table... {table_id} to: {destination}")


def _execute_unload_query(table_id, query, max_file_size):
    """
    Helper method to add the s3 destination and authentication credentials to
    the query.

    :param table_id: table to unload into s3.
    :type table_id: str

    :param query: SQL query.
    :type query: str

    :param max_file_size: maximum file size for exporting before batching.
    :type max_file_size: float

    :return: newly created s3 folder where the results are exported.
    :rtype: str
    """
    bucket_name = env_keys.S3_BUCKET_NAME
    destination = f's3://{bucket_name}/{table_id}_output/'
    query_str = f"""
        UNLOAD ('{query}')
        TO '{destination}'
        MAXFILESIZE {max_file_size} MB
        {_get_service_credentials()};"""
    run_query(query_str)
    return destination


def _import_table_from_s3(table_id, s3_path):
    """
    Import data into a redshift table by copying the requested data from s3
    to redshift.

    :param table_id: the newly generated table inside the redshift database.
    :type table_id: str

    :param s3_path: the absolute path to an s3 folder to copy data from.
    :type s3_path: str
    """
    origin = s3_path
    query_str = f"""
        COPY \"{table_id}\"
        FROM '{origin}'
        {_get_service_credentials()}
        CSV GZIP;"""
    run_query(query_str)
    logger.debug(f"imported data to... {table_id}")


# ==============================================================================
# S3 BUCKET METHODS
# ==============================================================================


def _upload_to_s3(gen, serializer, s3_folder, chunk_size=100_000):
    """
    Upload the given data to s3 under a new folder named after `s3_folder`.

    :param gen: iterable of entries to upload to s3 - must be serializable to
        string.
    :type gen: iter

    :param serializer: should be able to serialize data entries into strings.
    :type serializer: Serializer

    :param s3_folder: used as the folder name in s3.
    :type s3_folder: str

    :param chunk_size: number of lines to upload per file.
    :type chunk_size: int
    """
    s3_client = _get_s3_client()
    bucket_name = env_keys.S3_BUCKET_NAME
    output_generator = more_itertools.peekable(gen)

    def get_chunk_of_rows(gen):
        return [
            serializer.toString(output)
            for output in itertools.islice(gen, chunk_size)
        ]

    chunk_idx = 0
    suffix = str(uuid.uuid4())[:8]
    while rows_to_upload := get_chunk_of_rows(output_generator):
        fname = f'{chunk_idx}_{suffix}.csv.gz'
        content = gzip.compress('\n'.join(rows_to_upload).encode())
        s3_client.put_object(Body=content,
                             Bucket=bucket_name,
                             Key='/'.join([s3_folder, fname]))
        chunk_idx += 1
    destination = get_s3_absolute_path(s3_folder, s3_file='')
    logger.debug(f"uploaded data to s3 bucket... {destination}")


def _download_from_s3(table_id):
    """
    All s3 files whose names contain the `table_id` prefix are downloaded.
    A line-by-line iterator is returned.

    :param table_id: download s3 files containing this prefix in their names.
    :type table_id: str

    :return: an iterator over the lines of the downloaded files.
    :rtype: iter[str]
    """
    output_files = _get_s3_outputs(table_id)
    for output_file in output_files:
        for line in _get_s3_file(output_file):
            yield line


def _get_s3_outputs(table_id):
    return _list_s3_folder(f'{table_id}_output/')


def _list_s3_folder(prefix):
    """
    List all files under an s3 folder.

    :param prefix: filter files containing the following prefix.
    :type prefix: str

    :return: iterable of s3 file names.
    :rtype: iter[str]
    """
    s3_client = _get_s3_client()
    next_token = None
    while True:
        args = {
            'Bucket': env_keys.S3_BUCKET_NAME,
            'Prefix': prefix,
        }
        if next_token:
            args['ContinuationToken'] = next_token
        response = s3_client.list_objects_v2(**args)
        for entry in response['Contents']:
            yield entry['Key']
        next_token = response.get('NextContinuationToken')
        if not response['IsTruncated']:
            break


def _get_s3_file(filepath):
    """
    Get s3 file contents.

    :param filepath: full file path to the s3 object to obtain file content.
    :type filepath: str

    :return: an iterator over the lines in the requested file.
    :rtype: iter[str]
    """
    s3_client = _get_s3_client()
    response = s3_client.get_object(Bucket=env_keys.S3_BUCKET_NAME,
                                    Key=filepath)
    for line in response['Body'].iter_lines():
        yield line.decode()


def _delete_bucket_folder(s3_folder):
    """
    Remove all files within the folder `s3_folder`, i.e. all s3 objects whose
    keys begin with the `s3_folder` prefix.

    :param s3_folder: files containing this prefix.
    :type s3_folder: str
    """
    s3_client = _get_s3_client()
    bucket_name = env_keys.S3_BUCKET_NAME
    files = _list_s3_folder(s3_folder)
    num_files = 0
    for batch in more_itertools.chunked(files, _MAX_FILES):
        s3_client.delete_objects(
            Bucket=bucket_name,
            Delete={'Objects': [{
                'Key': file
            } for file in batch]})
        num_files += len(batch)
    logger.debug(f'deleted {num_files} files from bucket: {bucket_name}')
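

# Illustrative use of the upload helpers (a sketch, not part of the production
# flow; `_StrSerializer` below is a stand-in for the stepper serializers,
# which provide `toString`/`fromString`):
#
#     class _StrSerializer:
#         def toString(self, value):
#             return value
#
#     _upload_to_s3(iter(['CCO', 'c1ccccc1']), _StrSerializer(), 'demo_folder')
#     uploaded_keys = list(_list_s3_folder('demo_folder'))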


def get_s3_absolute_path(s3_folder, s3_file=None):
    """
    Build the absolute s3 path for a folder (and, optionally, a file) in the
    configured bucket.

    :param s3_folder: folder name inside the bucket.
    :type s3_folder: str

    :param s3_file: optional file name inside the folder.
    :type s3_file: str

    :return: the absolute s3 path.
    :rtype: str
    """
    components = ['s3:/', env_keys.S3_BUCKET_NAME, s3_folder]
    if s3_file:
        components.append(s3_file)
    return '/'.join(components)
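

# Example (illustrative; assumes S3_BUCKET_NAME is set to "my-bucket"):
#
#     get_s3_absolute_path('my_table_output', s3_file='0_ab12cd34.csv.gz')
#     # -> 's3://my-bucket/my_table_output/0_ab12cd34.csv.gz'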


# ==============================================================================
# SETTINGS
# ==============================================================================


class RSTable(parameters.CompoundParam):
    """
    Settings identifying a Redshift table by name.
    """
    table_name: str = None

    def getFullTableId(self):
        if not self.table_name:
            raise ValueError("Table name is not specified.")
        # Note: step names are typically used for table names so as a courtesy
        # we make a simple replacement
        name = self.table_name.replace('.', '_')
        if env_keys.CLOUD_ARTIFACT_PREFIX:
            name = '/'.join([env_keys.CLOUD_ARTIFACT_PREFIX, name])
        if not re.match(r'^[a-zA-Z][\w/]+$', name):
            raise ValueError(f'{INVALID_TABLE_NAME} - {name=}')
        return name
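

# Example (illustrative; assumes CLOUD_ARTIFACT_PREFIX is unset):
#
#     RSTable(table_name='my_workflow.dedupe_step').getFullTableId()
#     # -> 'my_workflow_dedupe_step'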


class S3Folder(parameters.CompoundParam):
    """
    Settings identifying a folder in the configured s3 bucket.
    """
    folder_name: str = None

    def getAbsolutePath(self):
        """
        :return: the absolute path to the folder in s3, e.g.
            s3://my_bucket/my_folder
        :rtype: str
        """
        return get_s3_absolute_path(self.folder_name)


class S3File(parameters.CompoundParam):
    """
    Settings identifying a single file in the configured s3 bucket.
    """
    filename: str = None


class RSTableExportSettings(parameters.CompoundParam):
    max_file_size: float = 5  # MB || range: min=5MB to max=6.2GB


class RSFilterSettings(parameters.CompoundParam):
    table: RSTable
    s3_folder: S3Folder


# ==============================================================================
# STEPS
# ==============================================================================


class _UploadToS3Step(basesteps.UploadStep):
    """
    First part of the upload step works by uploading the input mols serialized
    to strings to s3.
    """
    Settings = S3Folder
    Output = S3Folder

    def reduceFunction(self, inps):
        s3_folder = self.settings.folder_name
        _upload_to_s3(inps, self._getInputSerializer(), s3_folder)
        yield self.settings


class _ExportFromS3ToRS(stepper.ReduceStep):
    """
    Second part of the upload step works by copying over the S3 folder to the
    redshift database under the requested table ID.
    """
    Settings = S3Folder
    Input = S3Folder
    Output = RSTable

    def reduceFunction(self, s3_folders):
        s3_paths = set([inp.getAbsolutePath() for inp in s3_folders])
        for path in s3_paths:
            rs_table = self._s3FolderToRSTable(path)
            _import_table_from_s3(rs_table.getFullTableId(), path)
            yield rs_table

    def _s3FolderToRSTable(self, s3_path):
        return RSTable(table_name=s3_path.split('/')[-1])


class _EnumerateS3Folder(stepper.MapStep):
    """
    This step only maps a given table name to its appropriate folder in S3 and
    enumerates over the list of files there that were batch exported from the
    previous filtering step.
    """
    Input = RSTable
    Output = S3File

    def mapFunction(self, table):
        table_id = table.getFullTableId()
        for batch_file in _get_s3_outputs(table_id):
            yield S3File(filename=batch_file)


class _DownloadFromS3Step(basesteps.DownloadStep):
    """
    Download results by looking for the folder newly created in s3 by either
    the deduplication or random sampling step. The output lines are serialized
    into mols.
    """
    Input = S3File

    def mapFunction(self, inp):
        op_serializer = self.getOutputSerializer()
        for line in _get_s3_file(inp.filename):
            yield op_serializer.fromString(line)


class _DeduplicateStep(basesteps.TableReduceStep):
    """
    Deduplicates given Redshift table and batch exports the results to S3.
    """
    Settings = RSTableExportSettings
    Input = RSTable
    Output = RSTable

    def _actOnTable(self, table_id):
        _deduplicate_table(table_id, self.settings.max_file_size)


class _RandomSampleStep(basesteps.TableReduceStep):
    """
    Randomly samples given Redshift table and batch exports the results to S3.
    """

    class Settings(RSTableExportSettings):
        n: int = 5000

    Input = RSTable
    Output = RSTable

    def _actOnTable(self, table_id):
        _random_sample_table(table_id, self.settings.n,
                             self.settings.max_file_size)


class _DeduplicateAndRandomSampleStep(_RandomSampleStep):
    """
    Deduplicates and randomly samples given Redshift table and batch exports
    the results to S3.
    """

    def _actOnTable(self, table_id):
        _deduplicate_and_random_sample_table(table_id, self.settings.n,
                                             self.settings.max_file_size)


class _DropTableStep(basesteps.TableReduceStep):
    """
    Drops the requested Redshift table, and outputs the table ID.
    """
    Input = RSTable
    Output = RSTable

    def _actOnTable(self, table_id):
        _drop_table(table_id)


# ==============================================================================
# CHAINS
# ==============================================================================


class RSFilter(MolMolMixin, basesteps.CloudFilterChain):
    """
    Generic Redshift filter with table setup and validation defined. Classes
    inheriting from `RSFilter` need to define `addFilterSteps` for filter
    steps.
    """
    Settings = RSFilterSettings

    def setUp(self):
        super(RSFilter, self).setUp()
        self[0].setSettings(**self.settings.s3_folder.toDict())
        # This is a bit of a hack to fix a bug when running this chain
        # with two levels of jobcontrol. See AD-359
        self._setConfig(self._getCanonicalizedConfig())

    def _setUpTable(self):
        table = self.settings.table
        if table.table_name is None:
            name = utils.generate_stepid_and_random_suffix(self)
            table.table_name = name.replace('-', '_')
        table_id = table.getFullTableId()
        _create_table(table_id)
        self.settings.s3_folder.folder_name = table_id

    def _validateTable(self):
        errs = []
        cluster_id = env_keys.REDSHIFT_CLUSTER_ID
        database = env_keys.REDSHIFT_DATABASE
        db_user = env_keys.REDSHIFT_DB_USER
        if not (cluster_id and database and db_user):
            errs.append(
                stepper.SettingsError(
                    self,
                    INSUFFICIENT_DATABASE_SETTINGS.format(
                        cluster_id, database, db_user)))
        if not env_keys.S3_BUCKET_NAME:
            errs.append(stepper.SettingsError(self, MISSING_BUCKET))
        return errs

    def buildChain(self):
        self.addStep(_UploadToS3Step(**self.settings.s3_folder.toDict()))
        self.addStep(_ExportFromS3ToRS())
        self.addFilterSteps()
        self.addDropTableStepInProduction()
        self.addStep(_EnumerateS3Folder())
        self.addStep(_DownloadFromS3Step())

    def addFilterSteps(self):
        raise NotImplementedError

    def addDropTableStepInProduction(self):
        if int(os.environ.get('SCHRODINGER_STEPPER_DEBUG', 0)):
            return
        self.addStep(_DropTableStep())


class RSUniqueSmilesFilter(RSFilter):
    """
    A Chain that takes in Mols, uploads them to Redshift, and deduplicates
    them.

    To use, set the table name you'd like to use in the step settings.
    """

    def addFilterSteps(self):
        self.addStep(_DeduplicateStep())


class RSRandomSampleFilter(RSFilter):
    """
    A Chain that takes in Mols, uploads them to Redshift, and outputs a random
    sample of them.

    To use, set the table name you'd like to use in the step settings. The
    settings also include an `n` setting for determining roughly how many rows
    should be sampled. Note that this is an approximate number and a few more
    or fewer may be output.
    """

    class Settings(RSFilterSettings):
        n: int = 5000

    def addFilterSteps(self):
        self.addStep(_RandomSampleStep(n=self.settings.n))


class RSDeduplicateAndRandomSampleFilter(RSRandomSampleFilter):
    """
    Same as RSRandomSampleFilter except the data is deduplicated before being
    randomly sampled.
    """

    def addFilterSteps(self):
        self.addStep(_DeduplicateAndRandomSampleStep(n=self.settings.n))
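

# Minimal usage sketch (illustrative only; assumes the REDSHIFT_* and
# S3_BUCKET_NAME environment variables described in `env_keys` are configured,
# that `mols` is an iterable of Mol inputs, and that the chain is driven
# through the standard stepper setInputs/getOutputs interface):
#
#     dedupe = RSUniqueSmilesFilter()
#     dedupe.settings.table.table_name = 'my_dedupe_table'
#     dedupe.setInputs(mols)
#     unique_mols = list(dedupe.getOutputs())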