# Source code for schrodinger.test.stu.client

"""
Client for communication with STU server.  The core functionality is in
`TestClient` and `ResultReporter`, but most is accessible from the module
level convenience functions `create`, `retrieve`, `update`, `delete`, and
`download`, which allow modification of the tests in the database.

@copyright: (c) Schrodinger, LLC. All rights reserved.
"""
import getpass
import io
import json
import os
import os.path
import sys
import tempfile
import time  # noqa: F401, mocked in test
import urllib
import zipfile

import backoff
import requests
# Disable InsecurePlatformWarning until we update to > Python 2.7.9
import requests.packages.urllib3

from http import HTTPStatus
from schrodinger.test.stu import common
from schrodinger.test.stu import testscripts
from schrodinger.utils import fileutils
from schrodinger.utils import machid

# Hack for SHARED-4292
# On Windows and macOS, force the ssl module to report that SNI is
# unsupported. NOTE(review): presumably this makes the HTTP stack skip sending
# the server name during the TLS handshake on those platforms — confirm
# against SHARED-4292. This mutates the process-wide ``ssl`` module, so any
# other code in the process that consults ``ssl.HAS_SNI`` is affected too.
if sys.platform == "win32" or sys.platform.startswith("darwin"):
    import ssl
    ssl.HAS_SNI = False

# Silence urllib3 warnings for the whole process. Requests made by this module
# default to ``verify=False``, which would otherwise warn on every call.
requests.packages.urllib3.disable_warnings()

logger = common.logger

_VERSION = 1
"""API version to use."""

# Default headers for API requests (the `headers` default of the client's
# request method).
_JSON_FORMAT = {'content-type': 'application/json'}
# Maximum length of the workup messages reported for one test; longer messages
# are truncated before being sent to the server.
MAX_MESSAGE_LENGTH = 10000

# Tags that mark a test as a Scival test.
SCIVAL_TAGS = {'scival', 'require:scival'}

# Adjust the default number of allowed requests to the same URL.
# NOTE(review): requests' HTTPAdapter declares ``max_retries=DEFAULT_RETRIES``
# as a default argument, which is bound when requests is imported, so this
# reassignment likely has no effect on adapters created afterwards (including
# those of requests.Session) — confirm. Mounting ``HTTPAdapter(max_retries=3)``
# on the session would be the explicit way to get this behavior.
requests.adapters.DEFAULT_RETRIES = 3

# Requests exceptions on which backoff should retry.
# ConnectionError and SSLError both derive from RequestException, so in effect
# every requests exception (HTTPError included) is retried unless the
# backoff ``giveup`` predicate reports it as fatal.
RETRIABLE_EXCEPTIONS = (requests.exceptions.ConnectionError,
                        requests.exceptions.RequestException,
                        requests.exceptions.SSLError)
# RequestException status codes that should be retried
RETRIABLE_STATUS_CODES = (
    HTTPStatus.REQUEST_TIMEOUT,  # 408
    HTTPStatus.LENGTH_REQUIRED,  # 411
    499,  # not available in HTTPStatus
    HTTPStatus.INTERNAL_SERVER_ERROR,  # 500
    HTTPStatus.BAD_GATEWAY,  # 502
    HTTPStatus.SERVICE_UNAVAILABLE,  # 503
)

# Connection timeout should be "slightly larger than a multiple of 3, which is
# the default TCP packet retransmission window."
# (see http://docs.python-requests.org/en/master/user/advanced/#timeouts)
# Both values are passed to requests as its (connect, read) timeout tuple.
CONNECTION_TIMEOUT_SECONDS = 60.1
READ_TIMEOUT_SECONDS = 3600

# Name of the environment variable that overrides the STU username.
STU_REMOTE_USER = "STU_REMOTE_USER"


class ClientError(Exception):
    """Raised when communication with the STU server fails in a generic way."""
class ClientValueError(ValueError):
    """
    Raised when the arguments given to a client function are wrong or do not
    fully determine what to operate on.
    """
######################### # Convenience functions #########################
def create(username, test, directory=None, upload=True):
    """
    Create a new test on the STU server and optionally upload its files.

    Any existing ``test.id`` is cleared before creation and replaced by the
    number the server assigns. If uploading fails, the newly created test is
    deleted again (best effort) and the original error is re-raised.

    :param username: STU username used to authenticate.
    :param test: Test to create; its ``id`` attribute is updated in place.
    :param directory: Directory to upload; if not given, TestClient.upload
        picks one based on the test.
    :param upload: Whether to upload the test files after creation.
    :return: The number assigned to the new test by the server.
    """
    myclient = TestClient.instance(username)
    test.id = None
    number = myclient.create(test)
    try:
        if number:
            test.id = number
            if upload:
                myclient.upload(test, directory)
    except BaseException:
        # Roll back the creation, but never let a failing cleanup replace the
        # error that made the rollback necessary.
        try:
            myclient.delete(number)
        except Exception:
            logger.exception(
                'Failed to delete test %s after a failed upload.', number)
        raise
    return number
def retrieve(username, *args, **kwargs):
    """
    Retrieve tests from the STU server on behalf of `username`.

    All remaining arguments are forwarded unchanged to TestClient.retrieve,
    whose result (a generator of tests) is returned.
    """
    return TestClient.instance(username).retrieve(*args, **kwargs)
def update(username, test, directory=None, upload=True):
    """
    Update a test's metadata and, optionally, re-upload its files.

    :return: The result of the metadata update if it was falsy or `upload`
        is False; otherwise the result of the file upload.
    """
    client = TestClient.instance(username)
    succeeded = client.update(test)
    if not (succeeded and upload):
        return succeeded
    return client.upload(test, directory)
def delete(username, test):
    """
    Delete `test` (a test object or a test number) from the STU server.

    :return: The result of TestClient.delete.
    """
    return TestClient.instance(username).delete(test)
def download(username, test, directory=None, overwrite=True):
    """
    Download and extract the files of `test` using the shared client.

    :return: The directory the files were extracted to.
    """
    return TestClient.instance(username).download(test, directory, overwrite)
class ResultReporter:
    """
    Reporter that will upload results and files for a specific test run.
    (This is the `report` method.)

    Also marks runs as complete, with the option to send an email about test
    failures to the interested users (the `completeRun` method).
    """

    # BLDMGR-3781: uploads larger than this (the body limit set in nginx) are
    # rejected with "connection reset by peer" rather than HTTP 413, so results
    # above it are reported as a failure and not uploaded.
    _UPLOAD_LIMIT_BYTES = 2000 * 1024 * 1024

    def __init__(self,
                 buildtype,
                 build_id,
                 mmshare,
                 local_system,
                 remote_system,
                 username=None,
                 release=None,
                 build_log_address=None,
                 comment=None,
                 base_url=None,
                 api_version=_VERSION):
        """
        Create a run on the server for this build and prepare to report.

        This performs HTTP requests: the build and the local/remote hosts are
        looked up (and created if missing) and a new run is created.

        :param release: Release name; defaults to ``local_system.release``.
        """
        self.client = _BaseClient(username, base_url, api_version)
        release = release or local_system.release
        self._local_system = local_system
        self._remote_system = remote_system
        self._localhost = None
        self._remotehost = None
        self.run = self._createRun(mmshare, release, buildtype, build_id,
                                   build_log_address, comment)

    def report(self, test, upload=True, files=None):
        """
        Report the result of ONE test and upload its files.

        If zipping the results fails, or the zip exceeds the upload limit, the
        test is reported as failed with an explanatory message. A zip that
        exceeds the limit is not uploaded.

        :param test: The test whose outcome is reported.
        :param upload: Whether to zip and upload the test's result files.
        :param files: Files to zip from ``test.directory``; if None, the whole
            directory is zipped.
        """
        data = self._getTestData(test)
        with tempfile.TemporaryFile('w+b') as content:
            size = 0
            if upload:
                try:
                    if files is None:
                        common.zip_directory(test.directory, content)
                    else:
                        common.zip_files(content, test.directory, files)
                    size = content.tell()
                    # Verify zip contents
                    content.seek(0)
                    common.verify_zip(content)
                except common.ZipError as err:
                    # The error's docstring is used as the failure name; fall
                    # back to the class name if it has none. Roundtrip string
                    # to utf-8 for prep for server.
                    err_name = (err.__doc__ or type(err).__name__).encode(
                        'utf-8', errors="replace").decode('utf-8')
                    set_failure_data(data, err_name, str(err))
                content.seek(0)
            should_upload = bool(upload and size)
            if should_upload and size > self._UPLOAD_LIMIT_BYTES:
                limit = self._UPLOAD_LIMIT_BYTES
                err_message = (
                    'Test results are too large, they must not '
                    'exceed 2000 MB when compressed to a ZIP file. '
                    'Current ZIP size is {:0.2f} MB, {} bytes too '
                    'large. Edit your test to decrease the result '
                    'size.'.format(
                        float(size) / (1024.0 * 1024.0), size - limit))
                err_name = 'Result exceeds size limit'
                set_failure_data(data, err_name, err_message)
                # Uploading anyway would only end in a connection reset.
                should_upload = False
            payload = json.dumps(data)
            try:
                response = self.client.post(self.client.item_uri('outcome'),
                                            data=payload)
            except requests.exceptions.HTTPError as err:
                # err.response is None for HTTPErrors raised without one.
                status_code = getattr(err.response, 'status_code', None)
                if status_code == 410:
                    logger.debug(f' {test} was deleted. Skipping report.')
                    return
                raise
            except Exception:
                msg = 'Failed to report for {}. Requested data: {}'.format(
                    test, payload)
                logger.exception(msg)
                raise
            if not should_upload:
                return
            address = response.headers['location'] + 'upload/'
            self.client.post(address,
                             data=content,
                             headers={
                                 'content-encoding': 'application/zip',
                             })

    def userRunURL(self):
        """Return the URL where users can go to see details about this run."""
        url = self.client._base + self.run
        # self.run is API-rooted (presumably "/api/v<N>/run/<id>/"); the user
        # page lives at the same path without the API prefix.
        return url.replace(self.client._api.lstrip('/'), '')

    def completeRun(self, duration, email=False):
        """
        Record that the run is complete, and include the total duration.

        If email is True, trigger an email about test failures.
        """
        data = dict(total_time=duration, complete=True)
        data = json.dumps(data)
        if email:
            # Doesn't actually create a resource, so status code is not 201.
            self.client.post(self.run + 'email/',
                             data=data,
                             required_statuses=None)
        else:
            self.client.patch(self.run, data=data)

    def _getOrCreateBuild(self, mmshare, release, buildtype, build_id):
        """
        Return the resource URI of the matching build, creating it if needed.

        :raise ClientValueError: If any identifying value is blank or more
            than one build matches.
        """
        data = dict(mmshare=mmshare,
                    release=release,
                    buildtype=buildtype,
                    build_id=build_id)
        if not all(data.values()):
            raise ClientValueError('Build is not fully identified - some '
                                   'required values are blank. Values: %s' %
                                   data)
        build_api = self.client.item_uri('build')
        # The identifying values are sent both as query params and as body.
        response = self.client.get(build_api, params=data, data=data)
        response = response.json()
        if response['meta']['total_count'] > 1:
            msg = ('More than one build found: %s' %
                   ', '.join(val['resource_uri']
                             for val in response['objects']))
            raise ClientValueError(msg)
        elif not response['meta']['total_count']:
            flat_git_hashes = {}
            for product_hashes in machid.get_product_git_hashes().values():
                flat_git_hashes.update(product_hashes)
            data['_git_hashes'] = flat_git_hashes
            response = self.client.post(build_api,
                                        data=json.dumps(data),
                                        required_statuses=(201, 303))
            response = response.json()
            return response['resource_uri']
        else:
            return response['objects'][0]['resource_uri']

    def _createRun(self, mmshare, release, buildtype, build_id,
                   build_log_address, comments):
        """Create a new run on the server and return its API-rooted URI."""
        build = self._getOrCreateBuild(mmshare, release, buildtype, build_id)
        data = dict(build=build,
                    build_log_address=build_log_address,
                    build_results_dir=None,
                    platform=self._local_system.platform,
                    executor=self.client._username,
                    comments=comments,
                    schrodinger=self._local_system.schrodinger,
                    total_time=0,
                    complete=False,
                    localhost=self.localhost,
                    remotehost=self.remotehost)
        data = json.dumps(data)
        response = self.client.post(self.client.item_uri('run'), data=data)
        return self.client.getResourceURI(response)

    @property
    def localhost(self):
        """Resource URI of the local host (looked up or created once)."""
        if not self._localhost:
            self._localhost = self._getOrCreateHost(
                **self._local_system.toDict())
        return self._localhost

    @property
    def remotehost(self):
        """
        Resource URI of the remote host (looked up or created once).

        Reuses the local host's URI when it is already known and the two
        systems compare equal.
        """
        if not self._remotehost:
            if self._localhost and self._local_system == self._remote_system:
                self._remotehost = self._localhost
            else:
                self._remotehost = self._getOrCreateHost(
                    **self._remote_system.toDict())
        return self._remotehost

    def _getOrCreateHost(self, **kwargs):
        """Return the URI of the host matching `kwargs`, creating it if needed."""
        response = self.client.get('host', params=dict(limit=1, **kwargs))
        objects = response.json()['objects']
        if objects:
            return objects[0]['resource_uri']
        response = self.client.post('host', data=json.dumps(kwargs))
        location = self.client.getResourceURI(response)
        return location

    def _getTestData(self, test):
        """
        Build the outcome payload for `test`.

        Workup messages longer than MAX_MESSAGE_LENGTH are truncated and
        annotated with a note saying so.
        """
        local = not test.runsRemotely()
        if len(test.workup_messages) > MAX_MESSAGE_LENGTH:
            suffix = '\n\nTRUNCATED due to message length > {}.\n'
            suffix = suffix.format(MAX_MESSAGE_LENGTH)
            workup_messages = test.workup_messages[:MAX_MESSAGE_LENGTH]
            workup_messages += suffix
        else:
            workup_messages = test.workup_messages
        post_data = dict(test=test.resource_uri,
                         passed=test.outcome,
                         run_time_sec=test.timing,
                         run=self.run,
                         workup_messages=workup_messages,
                         failure_type=dict(name=test.failure_type))
        if local:
            post_data['host'] = self.localhost
        else:
            post_data['host'] = self.remotehost
        return post_data
def set_failure_data(data, error_name, error_message):
    """
    Mark reported test data as failed for the STU API, modifying it in place.

    :param data: Test data dict; must already have a ``workup_messages`` entry.
    :param error_name: Stored as the ``failure_type``.
    :param error_message: Prepended to the existing workup messages.
    """
    data['passed'] = False
    previous_messages = data['workup_messages']
    data['workup_messages'] = f'{error_message}\n\n\n{previous_messages}'
    data['failure_type'] = error_name
def raise_for_status(response):
    """
    If `response` has a bad status, raise an Exception.

    Before re-raising the HTTPError, its message is enriched with whatever
    error details the server supplied (the ``error_message``/``error`` field
    of a JSON object body, or else the raw body text), plus the request URL
    and the response headers. A server traceback found in a JSON object body
    is printed.

    :type response: requests.models.Response
    :param response: Response to check for exit status problems.
    :raise requests.exceptions.HTTPError: If the response status is bad.
    """
    try:
        response.raise_for_status()
    except requests.exceptions.HTTPError as err:
        # Guard against an HTTPError constructed without a message.
        err_args = list(err.args) or ['']
        message = str(err_args[0])
        try:
            data = response.json()
        except ValueError:
            # Response does not contain valid json
            data = None
        if isinstance(data, dict):
            error = (data.get('error_message') or data.get('error') or
                     response.reason)
            if error:
                error = str(error)
                if not message.endswith(error):
                    message = f'{message}, {error}'
            server_traceback = data.get('traceback')
            if server_traceback:
                print('Server ' + str(server_traceback).replace('\n\n', '\n'))
        else:
            # Not JSON, or JSON that is not an object: use the raw body.
            description = getattr(response, 'text', False)
            if description and not message.endswith(description):
                message = f'{message}, {description}'
        message += f" ({response.url} : {response.headers} )"
        err_args[0] = message
        err.args = tuple(err_args)
        raise
def check_status(response, required_statuses):
    """
    Raise an error if the status is not one of `required_statuses`.

    :type response: requests.models.Response
    :param response: Response whose status code is checked.
    :type required_statuses: collection of int
    :param required_statuses: Acceptable status codes.
    :rtype: bool
    :return: True if the status code is acceptable.
    :raise requests.exceptions.HTTPError: If the status code is not
        acceptable. The response is attached to the error.
    """
    if response.status_code in required_statuses:
        return True
    msg = ('Response %s (%s) did not match required status "%s"' %
           (response.reason, response.status_code, required_statuses))
    # Match exception class of raise_for_status. Attach the response so that
    # retry logic sees the real status code (an HTTPError without a response
    # is treated as retriable) and handlers reading err.response.status_code
    # do not fail on None.
    raise requests.exceptions.HTTPError(msg, response=response)
def fatal_status(exception):
    """
    Backoff ``giveup`` predicate: decide whether retrying is pointless.

    :return: True only for an HTTPError carrying a response whose status code
        is not in RETRIABLE_STATUS_CODES; False for everything else.
    """
    if isinstance(exception, requests.exceptions.HTTPError):
        # HTTPErrors should normally have a response, but if the calling code
        # does "raise HTTPError()", the response member is initialized to None
        response = exception.response
        if hasattr(response, "status_code"):
            return response.status_code not in RETRIABLE_STATUS_CODES
    return False
class ApiKeyAuth(requests.auth.AuthBase):
    """
    Requests authentication hook that sends an STU API key.

    Each request gets an ``Authorization: ApiKey <username>:<key>`` header,
    with the key obtained from ``common.get_api_key()`` at construction time.
    """

    def __init__(self, username):
        key = common.get_api_key()
        self._auth = 'ApiKey {}:{}'.format(username, key)

    def __call__(self, r):
        r.headers['Authorization'] = self._auth
        return r
class _BaseClient:
    """
    Base client class for submitting requests to a server. Adds defaults for
    server url, authentication, and retrying.
    """

    # Singleton instance of the client to simplify access for apps that always
    # use the defaults. Each class keeps its own instance; see `instance`.
    _INSTANCE = None

    def __init__(self, username=None, base_url=None, api_version=_VERSION):
        """
        :param username: STU username used for API-key authentication.
        :param base_url: Server URL; defaults to ``common.BASE_URL``.
        :param api_version: Version used to build the ``/api/v<N>/`` prefix.
        """
        if not base_url:
            base_url = common.BASE_URL
        self._base = base_url.rstrip('/')
        self._api = '/api/v%s/' % api_version
        # Built lazily by the `auth` property.
        self._auth = None
        self._username = username
        self._session = requests.Session()

    @classmethod
    def instance(cls, username, base_url=None, api_version=_VERSION):
        """
        Return the shared client of this class, creating it on first use.

        The instance is cached on the class it was requested from, so a
        subclass never receives an instance of its parent class that was
        cached by an earlier call on the parent.

        `base_url` and `api_version` are only used when the instance is
        created. If `username` differs from the cached instance's, its
        authentication is reset so it is rebuilt for the new user.
        """
        # Look only in this class's own namespace: a normal attribute lookup
        # would also find an instance cached on a parent class.
        client = cls.__dict__.get('_INSTANCE')
        if not client:
            client = cls(username, base_url, api_version)
            cls._INSTANCE = client
        if username != client._username:
            client._auth = None
            client._username = username
        return client

    @property
    def auth(self):
        """API-key authentication for the current username (built lazily)."""
        if not self._auth:
            self._auth = ApiKeyAuth(self._username)
        return self._auth

    @property
    def fullapi(self):
        """Server URL followed by the API prefix."""
        return self._base + self._api

    def item_uri(self, item):
        """Return the full URI of the API collection `item` (trailing slash)."""
        return f"{self.fullapi}{item}/"

    def list2str(self, items):
        """Join `items` into a comma-separated string for query parameters."""
        return ','.join(str(x) for x in items)

    def safename(self, oldname):
        """
        Ensure that the name is safe and that path separators are consistent
        on Linux and Windows.

        :type oldname: str
        :param oldname: String to be protected

        :rtype: str
        :return: String with all path separators replaced by / and all
            non-URL-safe characters (including /) percent-encoded.
        """
        newname = oldname.replace(os.path.sep, '/')
        newname = newname.replace('\\', '/')
        return urllib.parse.quote(newname, '')

    @backoff.on_exception(backoff.expo,
                          RETRIABLE_EXCEPTIONS,
                          max_tries=4,
                          giveup=fatal_status,
                          logger=logger,
                          factor=30)
    def _request(self,
                 method,
                 uri,
                 headers=_JSON_FORMAT,
                 auth=None,
                 required_statuses=None,
                 verify=False,
                 **kwargs):
        """
        Send a request with default auth, timeouts and status checks.

        Retried with exponential backoff (at most 4 tries) unless
        `fatal_status` reports the error as non-retriable.

        :param method: Bound session method to call, e.g. ``self._session.get``.
        :param uri: Absolute URL, server-rooted path (leading ``/``), or an
            API collection name resolved via `item_uri`.
        :param headers: Request headers; JSON content type by default.
        :param auth: Authentication; defaults to `auth`.
        :param required_statuses: If given, the acceptable status codes.
        :param verify: TLS certificate verification flag passed to requests.
            NOTE(review): verification is disabled by default.
        """
        if auth is None:
            auth = self.auth
        if not uri.startswith('http'):
            if uri.startswith('/'):
                uri = self._base + uri
            else:
                uri = self.item_uri(uri)
        response = method(uri,
                          headers=headers,
                          auth=auth,
                          verify=verify,
                          timeout=(CONNECTION_TIMEOUT_SECONDS,
                                   READ_TIMEOUT_SECONDS),
                          **kwargs)
        raise_for_status(response)
        if required_statuses:
            check_status(response, required_statuses)
        return response

    def post(self, uri, required_statuses=(201,), **kwargs):
        return self._request(self._session.post,
                             uri,
                             required_statuses=required_statuses,
                             **kwargs)

    def put(self, uri, **kwargs):
        return self._request(self._session.put, uri, **kwargs)

    def patch(self, uri, required_statuses=(202,), **kwargs):
        return self._request(self._session.patch,
                             uri,
                             required_statuses=required_statuses,
                             **kwargs)

    def get(self, uri, **kwargs):
        return self._request(self._session.get, uri, **kwargs)

    def delete(self, uri, **kwargs):
        return self._request(self._session.delete, uri, **kwargs)

    def getResourceURI(self, response=None):
        """Get an API rooted address from a full address."""
        location = response.headers['location']
        return location.replace(self._base, '')
class TestClient(_BaseClient):
    """
    Interact with the Test server.

    Create, Retrieve, Update, and Delete tests. Also upload and download the
    associated files.
    """

    def __init__(self, username=None, base_url=None, api_version=_VERSION):
        super().__init__(username=username,
                         base_url=base_url,
                         api_version=api_version)
        self._testapi = self.item_uri('test')
        self._sharedapi = self.item_uri('sharedfile')

    def create(self, test):
        """
        Create a test.

        :return: The number the server assigned to the new test.
        """
        test_dict = test.toDict()
        post_data = json.dumps(test_dict)
        response = self.post(self._testapi, data=post_data)
        return response.json()['number']

    def retrieve(
            self,
            test_ids=None,
            products=None,
            components=None,
            priorities=None,
            tags=None,
            not_products=tuple(),  # noqa: M511
            not_components=tuple(),  # noqa: M511
            not_tags=tuple()):  # noqa: M511
        """
        Retrieve tests based on some limiting criteria.

        The only method that operates on more than one test at a time. This
        is a generator, so the server is only queried once iteration starts.

        :param test_ids: Test numbers to include.
        :param products: Products to include (the server also matches JIRA
            names).
        :param components: Components to include.
        :param priorities: Priorities to include.
        :param tags: Tags to include.
        :param not_products: Products to skip (case-insensitive).
        :param not_components: Components to skip (case-insensitive).
        :param not_tags: Tags to skip (case-insensitive).
        :raise ClientError: If the server did not return every matching test.
        """
        query_dict = dict(limit=0)
        if test_ids:
            query_dict['number__in'] = self.list2str(test_ids)
        if products:
            # Use the custom STU query that searches JIRA names too
            query_dict['products'] = self.list2str(products)
        if components:
            query_dict['component__in'] = self.list2str(components)
        if priorities:
            query_dict['priority__in'] = self.list2str(priorities)
        if tags:
            query_dict['tags__in'] = self.list2str(tags)
        not_products = {product.lower() for product in not_products or ()}
        not_components = {
            component.lower() for component in not_components or ()
        }
        not_tags = {tag.lower() for tag in not_tags or ()}

        response = self.get(self._testapi, params=query_dict)
        data = response.json()
        if data['meta']['total_count'] != len(data['objects']):
            raise ClientError(
                'Did not download all tests! (%s requested. %s '
                'found)' % (data['meta']['total_count'], len(data['objects'])))
        for test_data in data['objects']:
            # Filter tests that use products, tags and subfeatures from the
            # skip lists. All comparisons are case-insensitive.
            if test_data['product'].lower() in not_products:
                continue
            component = test_data['component']
            if component and component.lower() in not_components:
                continue
            if any(tag.lower() in not_tags for tag in test_data['tags']):
                if has_scival_tags(test_data['tags']):
                    msg = f'Skipping scival test: {test_data["number"]}'
                    # Louder when the caller explicitly asked for this test.
                    if test_ids and test_data['number'] in test_ids:
                        logger.critical(msg)
                    else:
                        logger.debug(msg)
                continue
            # Backwards compatibility for SHARED-3063
            if 'jira_tracking' not in test_data and test_data.get(
                    'bug_disabled'):
                test_data['disabled_for_bug'] = True
                test_data['jira_tracking'] = test_data.pop('bug_disabled')
            yield testscripts.TestScript(**test_data)

    def find_one(self, criteria):
        """
        Find a single STU test that matches the search criteria.

        :param criteria: Search criteria, uses the exact names of parameters
            on the server. (does not interpret them in the way that retrieve
            does)
        :type criteria: dict

        :raise ClientError: If more than one test is found.
        :raise IndexError: If no tests are found.
        """
        response = self.get(self._testapi, params=criteria)
        data = response.json()
        total_count = data['meta']['total_count']
        if total_count > 1:
            raise ClientError(
                'Found too many tests matching criteria {}. Found {} tests'.
                format(criteria, total_count))
        try:
            test_data = data['objects'][0]
        except IndexError:
            raise IndexError(
                'Did not find a test matching the criteria: {}'.format(
                    criteria))
        return testscripts.TestScript(**test_data)

    def get_or_create(self, search_criteria, creation_data=None):
        """
        Find a test corresponding to search_criteria. If no test is found,
        create a test using creation_data.

        :param search_criteria: Search criteria.
        :type search_criteria: dict

        :param creation_data: Data to be used to create the test if none
            exists; defaults to `search_criteria`.
        :type creation_data: dict
        """
        try:
            test = self.find_one(search_criteria)
        except IndexError:
            # Doesn't exist yet.
            if creation_data is None:
                creation_data = search_criteria
            test = testscripts.TestScript(**creation_data)
            # Copy so the caller's dict is not modified.
            creation_data = creation_data.copy()
            creation_data['number'] = self.create(test)
            # Need to get again to have access to the URL for the test.
            test = self.find_one(creation_data)
            assert test.id
        return test

    def update(self, test):
        """
        Update the metadata of a test on the server.

        Files are not sent here; use `upload` for those.
        """
        test_dict = test.toDict()
        post_data = json.dumps(test_dict)
        self.put(self._testapi + '%s/' % test.id, data=post_data)
        return True

    def delete(self, test):
        """
        Delete a test from the server.

        :param test: A test object (its ``id`` is used) or a test number.
        """
        if hasattr(test, 'id'):
            testid = test.id
        else:
            testid = int(test)
        super().delete(self._testapi + "%s/" % testid)
        return True

    def download(self, test, directory=None, overwrite=True):
        """
        Download and extract the files associated with `test`.

        :param directory: Destination; defaults to ``test.directory``.
        :param overwrite: If True, an existing destination directory is
            removed first and existing local shared files are replaced.
        :raise ClientValueError: For Scival tests, which cannot be downloaded.
        :return: The directory the files were extracted to.
        """
        if has_scival_tags(test.tags):
            raise ClientValueError(f'Cannot download scival test: {test}.')
        directory = self._getTestDir(test, directory, overwrite)
        self._getSharedFiles(test.shared_files, overwrite)
        return directory

    def upload(self, test, directory=None):
        """
        Upload a directory to the server.

        If `directory` is not given, the test's directory attribute is used,
        falling back to a directory named after the test number. The test's
        shared files are uploaded first.
        """
        self._setSharedFiles(test.shared_files)
        self._setTestDir(test, directory)
        return True

    def _setTestDir(self, test, directory=None):
        """Zip the test directory (without README) and upload it."""
        path = directory or test.directory or str(test.id)
        test.directory = path
        content = io.BytesIO()
        common.zip_directory(path, fileobj=content, skipped_files={'README'})
        # Fresh buffer positioned at the start for the upload.
        content = io.BytesIO(content.getvalue())
        files = {'%s.zip' % test.id: content}
        if getattr(test, 'upload', None):
            uri = self._base + test.upload
        else:
            response = self.get(self._testapi, params={'number': test.id})
            uri = self._base + response.json()['objects'][0]['upload']
        self.post(uri,
                  files=files,
                  headers={'content-encoding': 'application/zip'},
                  required_statuses=None)

    def _setSharedFiles(self, shared_files):
        """
        Upload shared files to server.

        Files from the local ``shared`` directory that the server already has
        are skipped.

        :raise RuntimeError: If a shared entry is a directory.
        """
        for filename in shared_files:
            fullname = os.path.join('shared', filename)
            if os.path.isdir(fullname):
                raise RuntimeError('"%s" is a directory. Only files can be '
                                   'shared.' % fullname)
            # Check to see if it already exists on the server:
            # This metadata check should be much cheaper than sending files.
            uri = self._sharedapi + '%s' % self.safename(filename)
            response = self.get(uri)
            data = response.json()
            if data['thefile']:
                continue
            with open(fullname, 'rb') as fileobj:
                uri = self._base + data['upload']
                self.post(uri,
                          files={'file': fileobj},
                          headers=None,
                          required_statuses=None)

    def _getTestDir(self, test, directory=None, overwrite=True):
        """
        Download the zipped test file and extract it to a directory.

        A README for the test is written into the directory afterwards.
        """
        if not directory:
            directory = test.directory
        if overwrite and os.path.isdir(directory):
            fileutils.force_rmtree(directory)
        test.directory = directory
        if test.download:
            uri = self._base + test.download
        else:
            response = self.get(self._testapi, params={'number': test.id})
            uri = self._base + response.json()['objects'][0]['download']
        zipped = self._getZip(uri)
        if not zipped.namelist():
            # empty zip file
            os.makedirs(directory, exist_ok=True)
        else:
            zipped.extractall(path=directory)
        test.write_readme(os.path.join(directory, 'README'))
        return directory

    # BLDMGR-3710 It seems that under high load, the server might return a
    # partial zip file with status 200. Retry twice before raising.
    # FIXME: remove retry logic once the server issue is fixed
    @backoff.on_exception(backoff.expo,
                          zipfile.BadZipfile,
                          max_tries=3,
                          logger=logger,
                          factor=60)
    def _getZip(self, uri):
        """Download `uri` and open the content as an in-memory zip file."""
        response = self.get(uri)
        content = io.BytesIO(response.content)
        return zipfile.ZipFile(content)

    def _getSharedFiles(self, shared_files, overwrite=False):
        """
        Download the shared files into the local ``shared`` directory.

        Existing local files are kept unless `overwrite` is True.
        """
        # Don't create the shared directory unless it is needed.
        if not shared_files:
            return None
        for filename in shared_files:
            local_name = os.path.join('shared', filename)
            if os.path.isfile(local_name) and not overwrite:
                continue
            directory = os.path.dirname(local_name)
            if directory:
                os.makedirs(directory, exist_ok=True)
            # Quote the name the same way _setSharedFiles does, so separators
            # and URL-special characters are sent consistently.
            uri = self._sharedapi + '%s/download/' % self.safename(filename)
            response = self.get(uri)
            with open(local_name, 'wb') as fh:
                fh.write(response.content)
        return None
def has_scival_tags(test_tags):
    """
    Return whether a test's tags mark it as a Scival test.

    :param test_tags: Tags for a STU test
    :type test_tags: list of str
    :rtype: bool
    """
    for scival_tag in SCIVAL_TAGS:
        if scival_tag in test_tags:
            return True
    return False
def get_stu_username():
    """
    Return the STU username to act as.

    The STU_REMOTE_USER environment variable takes precedence (even when it
    is set to an empty string); otherwise the current login name from
    `getpass.getuser` is used.

    :rtype: str
    """
    username = os.getenv(STU_REMOTE_USER)
    if username is None:
        # Only look up the login name when it is needed: getpass.getuser()
        # raises when no user name can be determined.
        username = getpass.getuser()
    return username