Source code for schrodinger.application.steps.bigquery_deduplication

"""
Steps for deduplicating and randomly sampling data (currently limited to
use with Mols).

See the chains at the bottom of the module for the steps you'll most likely want
to use.
"""
import copy
import time
import itertools
import os
import uuid

import more_itertools

from schrodinger.application.steps import basesteps
from schrodinger.models import parameters
from schrodinger import stepper
from schrodinger.stepper import logger

from .basesteps import MolMolMixin

try:
    from google.auth import compute_engine
    from google.cloud import bigquery
    from google.cloud import exceptions
    from google.cloud import storage
    from google.oauth2 import service_account
except ImportError:
    compute_engine = None
    bigquery = None
    storage = None
    exceptions = None
    service_account = None

# Scopes
BIGQUERY_SCOPE = "https://www.googleapis.com/auth/bigquery"
CLOUD_SCOPE = "https://www.googleapis.com/auth/cloud-platform"

# Error messages
NO_DATASET = 'No bigquery dataset set'
_NO_PROJECT_ERR_MSG = (
    "No bigquery project is defined for this run. Set the "
    "SCHRODINGER_GCP_PROJECT environment variable and try again.")

#===============================================================================
# BigQuery Functions
#===============================================================================
PROJECT = os.environ.get('SCHRODINGER_GCP_PROJECT')
KEY_PATH = os.environ.get('SCHRODINGER_GCP_KEY')  # Service account key
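# Example configuration (hypothetical values), exported before launching a
# workflow:
#   export SCHRODINGER_GCP_PROJECT=my-gcp-project
#   export SCHRODINGER_GCP_KEY=/path/to/service-account-key.json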
BQ_CLIENT = None


def _create_table(table_id):
    table_id = _get_fully_qualified_table_id(table_id)
    client = _get_bq_client()
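    # Every table created by these steps has a single required STRING column
    # named "data" that holds one serialized output per row.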
    schema = [bigquery.SchemaField("data", "STRING", mode="REQUIRED")]

    table = bigquery.Table(table_id, schema=schema)
    client.create_table(table)
    logger.debug(f"creating table... {table_id}")


def _is_bigquery_enabled() -> bool:
    return bool(os.environ.get('SCHRODINGER_GCP_ENABLED', False))


def _wait_for_query_job(job):
    result = job.result()
    if job.errors:
        err_str = (f"Errors while running query {job.query}\n" +
                   f"{job.errors=}")
        logger.error(err_str)
    return result


def _dedupe_table(input_table_id, output_table_id):
    bq_client = _get_bq_client()
    input_table_id = _get_fully_qualified_table_id(input_table_id)
    output_table_id = _get_fully_qualified_table_id(output_table_id)
    job_config = bigquery.QueryJobConfig(destination=output_table_id,
                                         write_disposition='WRITE_TRUNCATE')
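    # Grouping on the lone `data` column collapses exact duplicate rows;
    # WRITE_TRUNCATE overwrites any existing contents of the destination table.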
    query_str = f"""
             SELECT
                 data
             FROM
                 `{input_table_id}`
             GROUP BY
                 data"""

    query_job = bq_client.query(query_str, job_config=job_config)
    _wait_for_query_job(query_job)
    logger.debug(f"deduped table {input_table_id}")


def _generate_credentials():
    if KEY_PATH is None:
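        # Without a service-account key, returning None lets the BigQuery
        # client fall back to Application Default Credentials.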
        return None
    credentials = service_account.Credentials.from_service_account_file(
        KEY_PATH,
        scopes=[CLOUD_SCOPE, BIGQUERY_SCOPE],
    )
    return credentials


def _generate_clients():
    credentials = _generate_credentials()
    bq_client = bigquery.Client(project=PROJECT, credentials=credentials)
    return bq_client


def _get_bq_client():
    global BQ_CLIENT
    if BQ_CLIENT is None:
        BQ_CLIENT = _generate_clients()
    return BQ_CLIENT


def _get_fully_qualified_table_id(table_id):
    if PROJECT is None:
        raise ValueError(_NO_PROJECT_ERR_MSG)
    if PROJECT not in table_id:
        return f'{PROJECT}.{table_id}'
    else:
        return table_id


def _stream_in_batches(gen,
                       serializer,
                       table_id,
                       chunk_size=10000,
                       skip_sleep=False):
    """
    Load batches of outputs into a table specified by `table_id`. outputs are
    batched so csv files are around `csv_size_limit` bytes. The csv files
    are written in chunks of `chunk_size` before being checked for size.

    After streaming in the data, this function will sleep for 5 seconds.
    This is to give BigQuery enough time to process the new results, otherwise
    any queries that happen immediately after this function will sometimes
    not process the new data. You can set `skip_sleep` to True if you don't
    expect to make any queries soon after.

    :param gen: A generator of outputs to load into the table
    :type  gen: Iterator

    :param serializer: A serializer to serialize the outputs, see `Serializer`
    :type  serializer: Serializer

    :param table_id: The table to load the outputs into. Should include both
        dataset and table name, i.e. "<DATASET>.<TABLE>"
    :type  table_id: str
    """
    client = _get_bq_client()
    table_id = _get_fully_qualified_table_id(table_id)
    output_generator = more_itertools.peekable(gen)

    def get_chunk_of_rows(gen):
        rows = []
        for output in itertools.islice(gen, chunk_size):
            rows.append((serializer.toString(output),))
        return rows

    fields = [bigquery.SchemaField("data", "STRING", mode="REQUIRED")]

    while True:
        rows_to_insert = get_chunk_of_rows(output_generator)
        if not rows_to_insert:
            break
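        # `insert_rows` is given a plain table ID string rather than a Table
        # object, so the schema must be supplied explicitly via
        # `selected_fields` (mirroring the schema used in `_create_table`).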
        errors = client.insert_rows(
            table_id, rows_to_insert,
            selected_fields=fields)  # Make an API request.
        if errors:
            logger.error(f"while streaming in rows: {errors}")
    if not skip_sleep:
        time.sleep(5)
    logger.debug(f"streamed in data to... {table_id}")


def _random_sample_table(src_table_id, dest_table_id, n):
    bq_client = _get_bq_client()
    input_table_id = _get_fully_qualified_table_id(src_table_id)
    output_table_id = _get_fully_qualified_table_id(dest_table_id)
    job_config = bigquery.QueryJobConfig(destination=output_table_id,
                                         write_disposition='WRITE_TRUNCATE')
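    # Each row is kept with probability n / COUNT(*), so the destination table
    # ends up with approximately `n` rows; the exact count varies from run to
    # run.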
    query_str = f"""
            SELECT
                data
            FROM
                `{input_table_id}`
            WHERE
                RAND() < {n}/(SELECT COUNT(*) FROM `{input_table_id}`);"""

    query_job = bq_client.query(query_str, job_config=job_config)
    _wait_for_query_job(query_job)


def _get_table_data(table_id, starting_idx=None, num_rows=None):
    """
    Get contents of a table specified by `table_id`. If `starting_idx` is
    specified, then the data will start at that row index. Up to `num_rows`
    will be returned.
    """
    if table_id is None:
        raise TypeError("table_id must be string, not None")
    bq_client = _get_bq_client()
    table_id = _get_fully_qualified_table_id(table_id)

    def _unwrap_row_iterator():
        for row in bq_client.list_rows(table_id,
                                       start_index=starting_idx,
                                       max_results=num_rows):
            yield row['data']

    return _unwrap_row_iterator()


def _table_row_count(table_id):
    # We use a query to get the number of rows. This has higher costs than
    # just checking a property on the table but table properties don't update
    # for streaming inserts.
    table_id = _get_fully_qualified_table_id(table_id)
    bq_client = _get_bq_client()
    query = f"SELECT data FROM `{table_id}`"
    query_job = bq_client.query(query)
    results = _wait_for_query_job(query_job)
    return results.total_rows


#===============================================================================
# Data and Setting Classes
#===============================================================================
class BQTable(parameters.CompoundParam):
    dataset: str = 'bq_testing_dataset'
    table_name: str = None

    def getFullTableId(self):
        if not self.table_name:
            raise ValueError("Table name is not specified.")
        return self.dataset.replace('.', '-') + '.' + self.table_name.replace(
            '.', '-')


class TableChunk(parameters.CompoundParam):
    start_idx: int
    chunk_size: int
    table: BQTable


class _DownloadSettings(BQTable):
    chunk_size: int = 10000


#===============================================================================
# Steps
#===============================================================================
class _UploadToBQStep(basesteps.UploadStep):
    Settings = BQTable
    Output = BQTable

    def reduceFunction(self, inps):
        table_id = self.settings.getFullTableId()
        _stream_in_batches(inps, self._getInputSerializer(), table_id)
        yield self.settings

    def validateSettings(self):
        if PROJECT is None:
            return [stepper.SettingsError(self, _NO_PROJECT_ERR_MSG)]
        return []


class _ChunkBigQueryTable(stepper.MapStep):

    class Settings(parameters.CompoundParam):
        chunk_size: int = 10000

    Input = BQTable
    Output = TableChunk

    def mapFunction(self, table):
        num_rows = _table_row_count(table.getFullTableId())
        for idx in range(0, num_rows, self.settings.chunk_size):
            table_copy = copy.deepcopy(table)
            yield TableChunk(start_idx=idx,
                             chunk_size=self.settings.chunk_size,
                             table=table_copy)


class _DownloadFromBQStep(basesteps.DownloadStep):
    Input = TableChunk

    def mapFunction(self, inp):
        op_serializer = self.getOutputSerializer()
        for datum in _get_table_data(inp.table.getFullTableId(),
                                     starting_idx=inp.start_idx,
                                     num_rows=inp.chunk_size):
            yield op_serializer.fromString(datum)


class _DedupeStep(basesteps.TableReduceStep):
    Input = BQTable
    Output = BQTable

    def _actOnTable(self, table_id):
        _dedupe_table(table_id, table_id)


class _RandomSampleStep(basesteps.TableReduceStep):
    """
    Step that randomly samples rows from a table in place; when chained after
    `_DedupeStep`, sampling occurs after deduplication. The `n` setting
    specifies the average number of rows to keep.
    """

    class Settings(parameters.CompoundParam):
        n: int = 5000

    Output = BQTable
    Input = BQTable

    def _actOnTable(self, table_id):
        _random_sample_table(table_id, table_id, self.settings.n)


#===============================================================================
# Chains
# To use these chains with pubsub, specify `use_pubsub` in the batch settings
# for the steps `_UploadToBQStep` and `_DownloadFromBQStep`
#===============================================================================
def _generate_stepid_and_random_suffix(step):
    return f"{step.getStepId()}_{str(uuid.uuid4())[:8]}"


def _validate_table(step, table):
    errs = []
    if not table.dataset:
        errs.append(stepper.SettingsError(step, NO_DATASET))
    return errs


def _setup_table(step, table):
    if table.table_name is None:
        table.table_name = _generate_stepid_and_random_suffix(step)
    table_id = table.getFullTableId()
    _create_table(table_id)


class BQUniqueSmilesFilter(MolMolMixin, stepper.Chain):
    """
    A Chain that takes in Mols, uploads them to BigQuery, and deduplicates
    them.

    To use, set the dataset and table name you'd like to use in the step
    settings. A table will be created in that dataset with that name. The BQ
    project is specified with the SCHRODINGER_GCP_PROJECT environment
    variable.
    """
    Settings = BQTable

    def setUp(self):
        _setup_table(self, self.settings)
        self[0].setSettings(**self.settings.toDict())
        # This is a bit of a hack to fix a bug when running this chain
        # with two levels of jobcontrol. See AD-359
        self._setConfig(self._getCanonicalizedConfig())

    def buildChain(self):
        self.addStep(_UploadToBQStep(**self.settings.toDict()))
        self.addStep(_DedupeStep())
        self.addStep(_ChunkBigQueryTable())
        self.addStep(_DownloadFromBQStep())

    def validateSettings(self):
        ret = super().validateSettings()
        return ret + _validate_table(self, self.settings)


class BQRandomSampleFilter(MolMolMixin, stepper.Chain):
    """
    A Chain that takes in Mols, uploads them to BigQuery, and outputs a
    random sample of them.

    To use, set the dataset and table name you'd like to use in the step
    settings. A table will be created in that dataset with that name. The
    settings also have an `n` setting for determining roughly how many rows
    should be sampled. Note that this is an approximate number and a few more
    or fewer may be output.

    The BQ project is specified with the SCHRODINGER_GCP_PROJECT environment
    variable.
    """

    class Settings(parameters.CompoundParam):
        table: BQTable
        n: int = 5000

    def setUp(self):
        _setup_table(self, self.settings.table)
        self[0].setSettings(**self.settings.table.toDict())
        # This is a bit of a hack to fix a bug when running this chain
        # with two levels of jobcontrol. See AD-359
        self._setConfig(self._getCanonicalizedConfig())

    def buildChain(self):
        self.addStep(_UploadToBQStep(**self.settings.table.toDict()))
        self.addStep(_RandomSampleStep(n=self.settings.n))
        self.addStep(_ChunkBigQueryTable())
        self.addStep(_DownloadFromBQStep())

    def validateSettings(self):
        ret = super().validateSettings()
        return ret + _validate_table(self, self.settings.table)


class BQDedupeAndRandomSampleFilter(MolMolMixin, stepper.Chain):
    """
    Same as BQRandomSampleFilter except the data is deduplicated before being
    randomly sampled.
    """

    class Settings(parameters.CompoundParam):
        table: BQTable
        n: int = 5000

    def setUp(self):
        _setup_table(self, self.settings.table)
        self[0].setSettings(**self.settings.table.toDict())
        # This is a bit of a hack to fix a bug when running this chain
        # with two levels of jobcontrol. See AD-359
        self._setConfig(self._getCanonicalizedConfig())

    def buildChain(self):
        self.addStep(_UploadToBQStep(**self.settings.table.toDict()))
        self.addStep(_DedupeStep())
        self.addStep(_RandomSampleStep(n=self.settings.n))
        self.addStep(_ChunkBigQueryTable())
        self.addStep(_DownloadFromBQStep())

    def validateSettings(self):
        ret = super().validateSettings()
        return ret + _validate_table(self, self.settings.table)
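

#===============================================================================
# Example usage
#===============================================================================
# A minimal sketch, assuming the standard stepper `setInputs()`/`getOutputs()`
# pattern and RDKit Mol inputs; the SMILES, dataset, and table name below are
# hypothetical placeholders.
#
#     from rdkit import Chem
#
#     mols = [Chem.MolFromSmiles(smi) for smi in ('CCO', 'CCO', 'c1ccccc1')]
#     dedupe = BQUniqueSmilesFilter(dataset='bq_testing_dataset',
#                                   table_name='dedupe_example')
#     dedupe.setInputs(mols)
#     unique_mols = dedupe.getOutputs()  # duplicates collapsed in BigQuery
#
# SCHRODINGER_GCP_PROJECT (and usually SCHRODINGER_GCP_KEY) must be set in the
# environment before running; see the module-level constants above.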