# Source code for schedy.core

# -*- coding: utf-8 -*-

from __future__ import absolute_import, division, print_function, unicode_literals
from builtins import *
from six import raise_from

from .experiments import Experiment, RandomSearch, ManualSearch, PopulationBasedTraining, _make_experiment
from .jwt import JWTTokenAuth
from .pagination import PageObjectsIterator
from . import errors, encoding
from .compat import json_dumps

import functools
import json
import requests
import os.path
import datetime
from requests.compat import urljoin, quote as urlquote
from requests.adapters import HTTPAdapter
from requests.packages.urllib3.util.retry import Retry
import logging

logger = logging.getLogger(__name__)

class SchedyRetry(Retry):
    '''
    Retry policy that logs a warning each time a request is about to be
    retried, then defers to the parent ``Retry`` implementation.
    '''

    # Overrides the backoff cap inherited from Retry (8 * 60).
    # NOTE(review): newer urllib3 releases reworked this knob
    # (DEFAULT_BACKOFF_MAX / a ``backoff_max`` argument) -- confirm that the
    # installed version still honours this class attribute.
    BACKOFF_MAX = 8 * 60

    def increment(self, method=None, url=None, response=None, error=None, *args, **kwargs):
        '''
        Log the failure, then let the parent class account for the retry.

        Args:
            method: HTTP method of the failed request, forwarded unchanged.
            url: URL of the failed request, forwarded unchanged.
            response: Response received from the server, if any. When it is
                not None, its ``data`` attribute is logged.
            error: Error encountered during the request, forwarded unchanged.
            *args: Extra positional arguments, forwarded unchanged.
            **kwargs: Extra keyword arguments, forwarded unchanged.

        Returns:
            Whatever the parent ``increment`` returns.
        '''
        logger.warning('Error while querying Schedy service, retrying.')
        if response is not None:
            logger.warning('Server message: %s', response.data)
        # Forward the named parameters positionally: mixing keywords with
        # *args would bind the extra positional values to ``method``/``url``
        # a second time and raise a TypeError.
        return super().increment(method, url, response, error, *args, **kwargs)

#: Maximum number of attempts for an authenticated request (the first try
#: included) when the server answers with an "unauthorized" status.
NUM_AUTH_RETRIES = 2

def _default_config_path():
    return os.path.join(os.path.expanduser('~'), '.schedy', 'client.json')

class SchedyDB(object):
    def __init__(self, config_path=None, config_override=None):
        '''
        SchedyDB is the central component of Schedy. It represents your
        connection to the Schedy service.

        Args:
            config_path (str or file-object): Path to the client configuration file. This file
                contains your credentials (email, API token). By default,
                ~/.schedy/client.json is used. See :ref:`setup` for
                instructions about how to use this file.
            config_override (dict): Content of the configuration. You can use this
                if you do not want to use a configuration file.
        '''
        self._load_config(config_path, config_override)
        # Add the trailing slash if it's not there, so that urljoin appends
        # to the root instead of replacing its last path segment.
        if not self.root.endswith('/'):
            self.root = self.root + '/'
        self._schedulers = dict()
        self._register_default_schedulers()
        # No token yet: the first authenticated request will fetch one.
        self._jwt_token = None
        self._jwt_expiration = datetime.datetime(year=1970, month=1, day=1)
        # The HTTP session is created lazily by _perform_request.
        self._session = None

    def _authenticate(self):
        '''
        Renew authentication.

        Posts the configured email and token to the ``passauth/`` endpoint
        (when ``token_type`` is ``'password'``) or to the ``token/`` endpoint
        (otherwise) and stores the resulting token in ``self._jwt_token``.
        You do not usually need to call this function, as it will always be
        called automatically when needed.

        Raises:
            schedy.errors.ServerError: If the response body is not valid JSON
                or does not contain usable token data.
        '''
        logger.debug('Renewing authentication')
        endpoint = 'passauth/' if self.token_type == 'password' else 'token/'
        url = urljoin(self.root, endpoint)
        response = self._perform_request('POST', url, json={'email': self.email, 'token': self.api_token})
        errors._handle_response_errors(response)
        try:
            token_data = response.json()
        except ValueError as e:
            raise_from(errors.ServerError('Response contains invalid JSON:\n' + response.text, None), e)
        try:
            jwt_token = token_data['token']
            expires_at = datetime.datetime.fromtimestamp(token_data['expiresAt'])
        except (KeyError, TypeError, ValueError, OverflowError, OSError) as e:
            # TypeError/ValueError cover a JSON payload that is valid but has
            # the wrong shape (e.g. a list, or a non-numeric expiration).
            raise_from(errors.ServerError('Response contains invalid token data.', None), e)
        self._jwt_token = JWTTokenAuth(jwt_token, expires_at)
        logger.debug('A new token was obtained.')

    def add_experiment(self, exp):
        '''
        Adds an experiment to the Schedy service. Use this function to create
        new experiments.

        Args:
            exp (schedy.Experiment): The experiment to add.

        Raises:
            schedy.errors.ResourceExistsError: If the server answers with a
                "precondition failed" status to the conditional PUT.

        Example:
            >>> db = schedy.SchedyDB()
            >>> exp = schedy.ManualSearch('TestExperiment')
            >>> db.add_experiment(exp)
        '''
        url = self._experiment_url(exp.name)
        content = exp._to_map_definition()
        data = json_dumps(content, cls=encoding.SchedyJSONEncoder)
        # "If-None-Match: *" makes the PUT conditional on the resource not
        # existing yet.
        response = self._authenticated_request('PUT', url, data=data, headers={'If-None-Match': '*'})
        # Handle code 412: Precondition failed
        if response.status_code == requests.codes.precondition_failed:
            raise errors.ResourceExistsError(response.text, response.status_code)
        errors._handle_response_errors(response)
        exp._db = self

    def get_experiment(self, name):
        '''
        Retrieves an experiment from the Schedy service by name.

        Args:
            name (str): Name of the experiment.

        Returns:
            schedy.Experiment: An experiment of the appropriate type.

        Raises:
            schedy.errors.ServerError: If the response is not a valid JSON
                dict or does not describe a valid experiment.

        Example:
            >>> db = schedy.SchedyDB()
            >>> exp = db.get_experiment('TestExperiment')
            >>> print(type(exp))
            <class 'schedy.experiments.ManualSearch'>
        '''
        url = self._experiment_url(name)
        response = self._authenticated_request('GET', url)
        errors._handle_response_errors(response)
        try:
            content = dict(response.json())
        except (ValueError, TypeError) as e:
            # TypeError: the body is valid JSON but cannot be turned into a
            # dict (e.g. a number or a list of scalars).
            raise_from(errors.ServerError('Response contains invalid JSON dict:\n' + response.text, None), e)
        try:
            exp = Experiment._from_map_definition(self._schedulers, content)
        except ValueError as e:
            raise_from(errors.ServerError('Response contains an invalid experiment', None), e)
        exp._db = self
        return exp

    def get_experiments(self):
        '''
        Retrieves all the experiments from the Schedy service.

        Returns:
            iterator of :py:class:`schedy.Experiment`: Iterator over all the experiments.
        '''
        return PageObjectsIterator(
            reqfunc=functools.partial(self._authenticated_request, 'GET', self._all_experiments_url()),
            obj_creation_func=functools.partial(_make_experiment, self),
        )

    def _register_scheduler(self, experiment_type):
        '''
        Registers a new type of experiment. You should never have to use this
        function yourself.

        Args:
            experiment_type (class): Type of the experiment, it must have an
                attribute called _SCHEDULER_NAME.
        '''
        self._schedulers[experiment_type._SCHEDULER_NAME] = experiment_type

    def _register_default_schedulers(self):
        '''Registers the experiment types shipped with the client.'''
        for experiment_type in (RandomSearch, ManualSearch, PopulationBasedTraining):
            self._register_scheduler(experiment_type)

    def _all_experiments_url(self):
        '''Returns the URL of the experiments collection.'''
        return urljoin(self.root, 'experiments/')

    def _experiment_url(self, name):
        '''Returns the URL of the experiment called ``name`` (fully percent-encoded).'''
        return urljoin(self._all_experiments_url(), '{}/'.format(urlquote(name, safe='')))

    def _job_url(self, experiment, job):
        '''Returns the URL of job ``job`` of experiment ``experiment`` (both fully percent-encoded).'''
        return urljoin(self.root, 'experiments/{}/jobs/{}/'.format(urlquote(experiment, safe=''), urlquote(job, safe='')))

    def _load_config(self, config_path, config):
        '''
        Reads the client configuration and stores it on the instance.

        Sets ``self.root``, ``self.email``, ``self.token_type`` and
        ``self.api_token``.

        Args:
            config_path (str or file-object): Where to read the configuration
                from when ``config`` is None. None means the default path.
            config (dict): Configuration content. When not None, it is used
                as is and ``config_path`` is ignored.

        Raises:
            ValueError: If ``token_type`` is neither ``api_token`` nor
                ``password``.
            KeyError: If ``root``, ``email`` or ``token`` is missing.
        '''
        if config is None:
            if config_path is None:
                config_path = _default_config_path()
            if hasattr(config_path, 'read'):
                # File-like object: parse its content directly.
                config = json.loads(config_path.read())
            else:
                with open(config_path) as f:
                    config = json.load(f)
        self.root = config['root']
        self.email = config['email']
        self.token_type = config.get('token_type', 'api_token')
        allowed_token_types = ['api_token', 'password']
        if self.token_type not in allowed_token_types:
            raise ValueError('Configuration value token_type must be one of {}.'.format(', '.join(allowed_token_types)))
        self.api_token = config['token']

    def _authenticated_request(self, *args, **kwargs):
        '''
        Performs a request carrying the current token, (re-)authenticating
        first when there is no token or when it is about to expire.

        The request is attempted at most NUM_AUTH_RETRIES times; a new
        attempt is made only when the server answers "unauthorized".

        Args:
            *args: Positional arguments forwarded to ``_perform_request``.
            **kwargs: Keyword arguments forwarded to ``_perform_request``.

        Returns:
            The last response received (None if NUM_AUTH_RETRIES is 0).
        '''
        response = None
        for _ in range(NUM_AUTH_RETRIES):
            if self._jwt_token is None or self._jwt_token.expires_soon():
                self._authenticate()
            response = self._perform_request(*args, auth=self._jwt_token, **kwargs)
            if response.status_code != requests.codes.unauthorized:
                break
            # The server rejected this token: forget it, so that the next
            # attempt (or the next call) fetches a fresh one instead of
            # resending the same rejected credentials.
            self._jwt_token = None
        return response

    def _make_session(self):
        '''
        Creates the HTTP session used for all requests, with an adapter that
        retries failed requests, and stores it in ``self._session``.
        '''
        # Careful: POST and PATCH are in the list of retried methods. This
        # means that the server should not be in an incomplete state or
        # POSTING and PATCHING twice could do weird things. We do this
        # because we do not want Schedy to crash in the face of the user when
        # there's a connection or benign error.
        retried_methods = frozenset(('HEAD', 'TRACE', 'GET', 'PUT', 'OPTIONS', 'DELETE', 'POST', 'PATCH'))
        # urllib3 1.26 renamed ``method_whitelist`` to ``allowed_methods`` and
        # urllib3 2.0 dropped the old name; use whichever the installed
        # version understands.
        if hasattr(SchedyRetry, 'DEFAULT_ALLOWED_METHODS'):
            methods_option = {'allowed_methods': retried_methods}
        else:
            methods_option = {'method_whitelist': retried_methods}
        retry_mgr = SchedyRetry(
            total=10,
            read=10,
            connect=10,
            backoff_factor=0.4,
            status_forcelist=frozenset((requests.codes.server_error, requests.codes.unavailable)),
            **methods_option
        )
        adapter = HTTPAdapter(max_retries=retry_mgr)
        session = requests.Session()
        session.mount('http://', adapter)
        session.mount('https://', adapter)
        # Publish the session only once it is fully configured, so that a
        # failure above cannot leave a session without the retry adapter.
        self._session = session

    def _perform_request(self, *args, **kwargs):
        '''
        Sends a request through the shared session (created on first use)
        and logs what was sent and received at debug level.

        Args:
            *args: Positional arguments forwarded to ``Session.request``.
            **kwargs: Keyword arguments forwarded to ``Session.request``.

        Returns:
            The response returned by ``Session.request``.
        '''
        if self._session is None:
            self._make_session()
        if 'data' in kwargs:
            logger.debug('Sent headers: %s', kwargs.get('headers'))
            logger.debug('Sent data: %s', kwargs['data'])
        response = self._session.request(*args, **kwargs)
        # NOTE(review): _authenticate goes through this method, so the body
        # logged below includes the freshly issued token when debug logging
        # is enabled.
        logger.debug('Received headers: %s', response.headers)
        logger.debug('Received data: %s', response.text)
        return response