# Source code for track.persistence.socketed

"""

    Implements a remote logger.
        The client forwards each of the user's requests to the server, which executes them one at a time.


"""
from track.utils.signal import SignalHandler
from track.persistence.protocol import Protocol
from track.persistence.utils import parse_uri
from track.utils import open_socket, listen_socket
from track.aggregators.aggregator import Aggregator, StatAggregator
from track.structure import Trial, TrialGroup, Project
from track.serialization import to_json, from_json
from track.utils.log import error, warning, info
from track.utils.throttle import throttle_repeated

from typing import Callable

import traceback
import time
import asyncio
import json
import struct


def to_bytes(message) -> bytes:
    """Serialize ``message`` to JSON and return it as UTF-8 encoded bytes."""
    text = json.dumps(message)
    return text.encode('utf8')
def to_obj(message: bytes) -> any:
    """Decode a JSON ``message`` and rebuild track objects from it with ``from_json``."""
    decoded = json.loads(message)
    return from_json(decoded)
def send(socket, msg):
    """Send ``msg`` as one length-prefixed JSON frame on ``socket``.

    Frame layout: a header packed with ``struct`` format ``'I'`` holding the total
    frame length (the ``+ 4`` accounts for the header itself), then the JSON payload.
    """
    payload = to_bytes(msg)
    header = bytearray(struct.pack('I', len(payload) + 4))
    socket.sendall(header + payload)
def recv(socket, timeout=None):
    """Receive one length-prefixed JSON message (as written by ``send``) and decode it.

    The first 4 bytes hold the total frame length (header included) packed with
    ``struct`` format ``'I'``; the remainder is the payload decoded by ``to_obj``.

    Parameters
    ----------
    socket:
        connected socket-like object exposing ``recv(n)``
    timeout: float, optional
        when set, empty reads are retried every 10 ms until ``timeout`` seconds of
        waiting have accumulated. When unset (or 0), an empty read is treated as the
        peer having closed the connection, as for a standard blocking socket.

    Raises
    ------
    ConnectionError
        the peer closed the connection before a full message arrived (no timeout)
    TimeoutError
        ``timeout`` elapsed before the full message arrived
    ValueError
        the header announces a size smaller than the header itself
    """
    header_size = 4
    data = b''
    size = None  # total frame length, known once the header has arrived
    elapsed = 0

    while size is None or len(data) < size:
        # Ask only for the bytes the current frame still needs so a read never
        # swallows the beginning of a following message.
        target = header_size if size is None else size
        chunk = socket.recv(target - len(data))

        if not chunk:
            # The check used to test the accumulated buffer (never empty here), so the
            # timeout could not fire and a closed peer made this loop spin forever.
            if not timeout:
                raise ConnectionError('Connection closed before the entire message was received')
            time.sleep(0.01)
            elapsed += 0.01
            if elapsed > timeout:
                raise TimeoutError('Was not able to receive the entire message in time')
            continue

        data += chunk
        if size is None and len(data) >= header_size:
            size = struct.unpack('I', data[:header_size])[0]
            if size < header_size:
                raise ValueError(f'Invalid message size: {size}')

    return to_obj(data[header_size:size])
class RPCCallFailure(Exception):
    """Raised on the client when the server reports that an RPC failed.

    Parameters
    ----------
    message: str
        error message sent back by the server
    trace: str, optional
        server-side traceback, when the reply carries one
    """

    def __init__(self, message, trace=None):
        super(RPCCallFailure, self).__init__(message)
        # Keep the remote traceback available to callers (it used to be discarded).
        self.trace = trace
def _check(result):
    """Unwrap an RPC reply.

    A reply with ``status == 0`` yields its deserialized ``return`` value; any
    other status is logged and raised as ``RPCCallFailure``.
    """
    if result['status'] == 0:
        return from_json(result['return'])

    message = result['error']
    error(f'RPC failed with error {message}')
    raise RPCCallFailure(message, result.get('trace'))
class SocketClient(Protocol):
    """Forward every local track request to a track server that executes it and replies.

    Each public method sends one RPC message (a dict whose ``__rpc__`` key names the
    server-side procedure) and blocks until the reply arrives; the reply is unwrapped
    by ``_check``, which raises ``RPCCallFailure`` when the server reports an error.

    Clients can provide a username and password for authentication; override
    ``authenticate`` to supply credentials another way.
    """

    # socket://[username:password@]host1[:port1][,...hostN[:portN]]][/[database][?options]]
    def __init__(self, uri):
        uri = parse_uri(uri)
        self.username = uri.get('username')
        self.password = uri.get('password')
        self.security_layer = uri['query'].get('security_layer')
        self.socket = open_socket(uri.get('address'), int(uri.get('port')), backend=self.security_layer)
        self.token = self._authenticate(uri)
        info(f'token: {self.token}')

    def _rpc(self, proc_name, payload=None):
        """Send ``payload`` as a call to ``proc_name`` and return the unwrapped reply.

        ``payload`` is a plain dict rather than ``**kwargs`` so user-supplied keys
        (e.g. a metric called ``name``) can never collide with this helper's parameters.
        """
        message = dict() if payload is None else payload
        message['__rpc__'] = proc_name
        send(self.socket, message)
        return _check(recv(self.socket))

    def authenticate(self, uri):
        """Return the ``(username, password)`` pair used for authentication purposes.

        You can override this function to implement a custom authentication method.
        """
        return self.username, self.password

    def _authenticate(self, uri):
        username, password = self.authenticate(uri)
        # Should we send the password hashed? The connection should be secure regardless,
        # plus how would we handle salting?
        return self._rpc('authenticate', {'username': username, 'password': password})

    def log_trial_chrono_start(self, trial, name: str, aggregator: Callable[[], Aggregator] = StatAggregator.lazy(1),
                               start_callback=None, end_callback=None):
        # aggregator and the callbacks are local callables; only trial uid and name are sent.
        return self._rpc('log_trial_chrono_start', {'trial': trial.uid, 'name': name})

    def log_trial_chrono_finish(self, trial, name, exc_type, exc_val, exc_tb):
        # Exception details are not forwarded: the server always receives None for them.
        return self._rpc('log_trial_chrono_finish', {
            'trial': trial.uid,
            'name': name,
            'exc_type': None,
            'exc_val': None,
            'exc_tb': None,
        })

    def log_trial_start(self, trial):
        return self._rpc('log_trial_start', {'trial': trial.uid})

    def log_trial_finish(self, trial, exc_type, exc_val, exc_tb):
        # Exception details are not forwarded: the server always receives None for them.
        return self._rpc('log_trial_finish', {
            'trial': trial.uid,
            'exc_type': None,
            'exc_val': None,
            'exc_tb': None,
        })

    def log_trial_arguments(self, trial: Trial, **kwargs):
        kwargs['trial'] = trial.uid
        return self._rpc('log_trial_arguments', kwargs)

    def log_trial_metadata(self, trial: Trial, aggregator: Callable[[], Aggregator] = None, **kwargs):
        kwargs['trial'] = trial.uid
        return self._rpc('log_trial_metadata', kwargs)

    def log_trial_metrics(self, trial: Trial, step: any = None, aggregator: Callable[[], Aggregator] = None, **kwargs):
        kwargs['trial'] = trial.uid
        kwargs['step'] = step
        return self._rpc('log_trial_metrics', kwargs)

    def set_trial_status(self, trial: Trial, status, error=None):
        # NOTE(review): `error` is not forwarded to the server — confirm this is intended.
        return self._rpc('set_trial_status', {'trial': trial.uid, 'status': to_json(status)})

    def add_trial_tags(self, trial, **kwargs):
        kwargs['trial'] = trial.uid
        return self._rpc('add_trial_tags', kwargs)

    # Object Creation
    def get_project(self, project: Project):
        payload = {'__rpc__': 'get_project', 'project': to_json(project)}
        info(payload)
        reply = self._rpc('get_project', payload)
        info(f'got reply {reply}')
        return reply

    def new_project(self, project: Project):
        return self._rpc('new_project', {'project': to_json(project)})

    def get_trial_group(self, group: TrialGroup):
        return self._rpc('get_trial_group', {'group': to_json(group)})

    def new_trial_group(self, group: TrialGroup):
        return self._rpc('new_trial_group', {'group': to_json(group)})

    def add_project_trial(self, project: Project, trial: Trial):
        return self._rpc('add_project_trial', {'project': to_json(project), 'trial': to_json(trial)})

    def add_group_trial(self, group: TrialGroup, trial: Trial):
        return self._rpc('add_group_trial', {'group': to_json(group), 'trial': to_json(trial)})

    def commit(self, **kwargs):
        return self._rpc('commit', kwargs)

    def get_trial(self, trial: Trial):
        # The RPC name used to be misspelled 'get_trail', which the server cannot resolve.
        return self._rpc('get_trial', {'trial': to_json(trial)})

    def new_trial(self, trial: Trial):
        return self._rpc('new_trial', {'trial': to_json(trial)})
async def read(reader, timeout=None):
    """Read one length-prefixed JSON message (as written by ``write``) from an asyncio StreamReader.

    The first 4 bytes hold the total frame length (header included) packed with
    ``struct`` format ``'I'``; the remainder is the payload decoded by ``to_obj``.

    Returns
    -------
    The decoded message, or ``None`` when the stream ended (EOF) before a complete
    message could be read.

    Raises
    ------
    TimeoutError
        ``timeout`` (seconds) is set and the message body did not arrive in time
    ValueError
        the header announces a size smaller than the header itself
    """
    header_size = 4
    try:
        # readexactly awaits on the event loop; the previous time.sleep polling blocked
        # every other connection served by the same loop.
        header = await reader.readexactly(header_size)
    except asyncio.IncompleteReadError as exc:
        if exc.partial:
            warning('Connection closed in the middle of a message header')
        return None

    size = struct.unpack('I', header)[0]
    if size < header_size:
        raise ValueError(f'Invalid message size: {size}')

    try:
        if timeout:
            body = await asyncio.wait_for(reader.readexactly(size - header_size), timeout)
        else:
            body = await reader.readexactly(size - header_size)
    except asyncio.TimeoutError:
        raise TimeoutError('Was not able to receive the entire message in time') from None
    except asyncio.IncompleteReadError:
        warning('Connection closed in the middle of a message')
        return None

    # Only this frame is consumed; any following message stays buffered in the reader.
    return to_obj(body)
def write(writer, msg):
    """Queue ``msg`` on an asyncio writer as one length-prefixed JSON frame.

    Uses the same frame layout as ``send``; the data is not drained here.
    """
    payload = to_bytes(msg)
    frame = bytearray(struct.pack('I', len(payload) + 4))
    frame += payload
    writer.write(frame)
class SocketServer(Protocol):
    """Start a track server inside an asyncio loop.

    Parameters
    ----------
    uri: str
        ``socket://{hostname}:{port}?backend={protocol}&security_layer={layer}``.
        ``backend`` is the protocol URI the requests are forwarded to (resolved with
        ``get_protocol``); ``security_layer`` is optional and passed to ``listen_socket``.

    Users can inherit this class and override ``authenticate`` to implement their own
    custom authentication.
    """

    def __init__(self, uri):
        from track.persistence import get_protocol

        uri = parse_uri(uri)
        self.address, self.port = uri.get('address'), int(uri.get('port'))
        self.security_layer = uri['query'].get('security_layer')
        self.backend = get_protocol(uri['query'].get('backend'))
        self.authentication = {}  # reader -> (username, password) of authenticated clients
        self.timeout = 10         # idle budget in seconds, accumulated by 10 ms sleeps in handle_client
        self.client_cache = {}
        self.sckt = None
        self.loop = None

    def authenticate(self, reader, username, password):
        """User defined authentication function.

        The default implementation accepts every client and records its credentials.

        Parameters
        ----------
        reader: StreamReader
            client socket / reader, can be used to link client socket -> username
        username: str
            client username
        password: str
            client password
        """
        self.authentication[reader] = (username, password)
        return {
            'status': 0,
            'return': True
        }

    # https://stackoverflow.com/questions/48506460/python-simple-socket-client-server-using-asyncio
    def run_server(self):
        """Listen on ``address:port`` and serve clients until the loop is stopped."""
        info(f'Server listening to {self.address}:{self.port}')
        self.sckt = listen_socket(self.address, self.port, backend=self.security_layer)
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        # Publish the loop *before* blocking so close() (e.g. from a signal handler) can
        # stop it; it used to be assigned only after run_forever had already returned.
        self.loop = loop
        loop.create_task(asyncio.start_server(self.handle_client, sock=self.sckt))
        try:
            loop.run_forever()
        finally:
            self.close()

    def process_args(self, args, cache=None):
        """Replace ids by their object reference so the backend modifies the objects and not a copy.

        A ``trial`` given as a ``'<hash>_<revision>'`` uid, a ``project`` given as a name
        and a ``group`` given as a uid are looked up in the backend; those three keys are
        then passed through ``from_json``. Other arguments are returned unchanged.
        ``cache`` is currently unused.
        """
        new_args = dict()
        for key, value in args.items():
            if key == 'trial':
                if isinstance(value, str):
                    value = self._lookup_trial(value)
                value = from_json(value)

            elif key == 'project':
                if isinstance(value, str):
                    value = self.backend.get_project(Project(name=value))
                value = from_json(value)

            elif key == 'group':
                if isinstance(value, str):
                    value = self.backend.get_trial_group(TrialGroup(_uid=value))
                value = from_json(value)

            new_args[key] = value
        return new_args

    def _lookup_trial(self, uid):
        """Return the backend trial matching a ``'<hash>_<revision>'`` uid.

        Falls back to everything ``backend.get_trial`` returned (with a warning) when no
        entry has the requested revision.
        """
        # Split on the last '_' only: the revision is the final component.
        hashid, rev = uid.rsplit('_', 1)
        rev = int(rev)
        trials = self.backend.get_trial(Trial(_hash=hashid, revision=rev))
        for trial in trials:
            if trial.revision == rev:
                return trial
        warning('Was not able to find the correct trial revision')
        return trials

    def exec(self, reader, writer, proc_name, proc, args, cache=None):
        """Run ``proc`` with the processed ``args`` and write exactly one reply.

        Any exception is logged and reported to the client as ``{'status': 1, 'error': ...}``
        instead of tearing down the connection.
        """
        try:
            new_args = self.process_args(args)
            answer = proc(**new_args)
            write(writer, {
                'status': 0,
                'return': to_json(answer)
            })
        except Exception as e:
            # get_username already returns the bare username (or None); indexing it used
            # to log only its first character, or raise TypeError for anonymous clients.
            error(f'An exception occurred while processing (rpc: {proc_name}) '
                  f'for user {self.get_username(reader)}')
            error(traceback.format_exc())
            write(writer, {
                'status': 1,
                'error': str(e)
            })

    @staticmethod
    async def wait_closed(writer):
        """Wait until ``writer`` is closed, when the running Python supports it."""
        waiter = getattr(writer, 'wait_closed', None)
        if waiter is None:
            return  # StreamWriter.wait_closed is python 3.7+
        try:
            await waiter()
        except ConnectionError:
            # The peer already dropped the connection; our side is closed either way.
            pass

    @staticmethod
    async def close_connection(writer):
        """Flush pending replies, then close ``writer`` and wait for it to be closed."""
        try:
            await writer.drain()
        except ConnectionError:
            pass  # peer gone: nothing left to flush, but our side must still be closed
        writer.close()
        await SocketServer.wait_closed(writer)

    async def handle_client(self, reader, writer):
        """Serve one client connection, writing one reply for every request read.

        Requests are dicts whose ``__rpc__`` key names the procedure to run; the other
        keys are its keyword arguments. ``authenticate`` is served by this class; any
        other procedure requires an authenticated client and is forwarded to the backend.
        While ``read`` keeps returning ``None``, the connection is closed once about
        ``self.timeout`` seconds of retries have accumulated.
        """
        info('Client Connected')
        idle_time = 0
        cache = {}
        info_proc = throttle_repeated(info, every=10)

        try:
            while True:
                request = await read(reader)

                if request is None:
                    # asyncio.sleep, not time.sleep: a blocking sleep would freeze every
                    # other client served by this event loop.
                    await asyncio.sleep(0.01)
                    idle_time += 0.01
                    if idle_time > self.timeout:
                        info(f'Client (user: {self.get_username(reader)}) is timing out')
                        await self.close_connection(writer)
                        return None
                    continue

                idle_time = 0
                self._dispatch(reader, writer, request, cache, info_proc)
        finally:
            # Forget the client's credentials and release its socket however we exit.
            self.authentication.pop(reader, None)
            writer.close()

    def _dispatch(self, reader, writer, request, cache, info_proc):
        """Validate one decoded request and write exactly one reply for it."""
        proc_name = request.pop('__rpc__', None) if isinstance(request, dict) else None
        info_proc(f'Processing Request: {proc_name} for (user: {self.get_username(reader)})')

        if proc_name is None:
            self._reject(writer, f'Could not process message (rpc: {request})')
            return

        if proc_name == 'authenticate':
            request['reader'] = reader
            self.exec(reader, writer, proc_name, self.authenticate, request, cache=cache)
            return

        if not self.is_authenticated(reader):
            self._reject(writer, f'Client is not authenticated cannot execute (proc: {proc_name})')
            return

        # proc_name comes from the network: only expose public callables of the backend.
        # getattr without a default used to raise here, killing the connection with no reply.
        proc = None
        if isinstance(proc_name, str) and not proc_name.startswith('_'):
            proc = getattr(self.backend, proc_name, None)
        if not callable(proc):
            backend_name = getattr(self.backend, '__name__', type(self.backend).__name__)
            self._reject(writer, f'{backend_name} does not implement (rpc: {proc_name})')
            return

        self.exec(reader, writer, proc_name, proc, request, cache=cache)

    @staticmethod
    def _reject(writer, message):
        """Log ``message`` and send it back to the client as a failed reply (status 1)."""
        error(message)
        write(writer, {
            'status': 1,
            'error': message
        })

    def get_username(self, reader):
        usr_pwd = self.authentication.get(reader)
        if usr_pwd is None:
            return None
        return usr_pwd[0]

    def is_authenticated(self, reader):
        return self.authentication.get(reader) is not None

    def commit(self, **kwargs):
        self.backend.commit(**kwargs)

    def close(self):
        """Stop serving and release the event loop and the listening socket.

        Safe to call more than once. While the loop is still running (e.g. when called
        from a signal handler) this only schedules ``loop.stop()``; ``run_server``
        completes the cleanup once ``run_forever`` returns.
        """
        loop = self.loop
        if loop is not None and loop.is_running():
            # call_soon_threadsafe also wakes the selector so the loop stops promptly;
            # closing a running loop would raise RuntimeError.
            loop.call_soon_threadsafe(loop.stop)
            return

        if loop is not None and not loop.is_closed():
            loop.close()

        if self.sckt is not None:
            info('Shutting down server')
            self.sckt.close()
            self.sckt = None
class ServerSignalHandler(SignalHandler):
    """Signal callbacks that close the wrapped track ``server``.

    Presumably ``SignalHandler`` routes SIGTERM/SIGINT to ``sigterm``/``sigint`` —
    confirm in ``track.utils.signal``.
    """

    def __init__(self, server):
        super().__init__()
        self.server = server

    def sigterm(self, signum, frame):
        server = self.server
        server.close()

    def sigint(self, signum, frame):
        server = self.server
        server.close()
def start_track_server(protocol, hostname, port, security_layer=None):
    """Start a track server inside an asyncio loop; blocks while the server runs.

    Parameters
    ----------
    protocol: str
        URI that defines which backend to forward the request to
    hostname: str
        server host name
    port: int
        server port to listen to
    security_layer: str
        backend used for encryption (only AES is supported)
    """
    security = '' if security_layer is None else f'&security_layer={security_layer}'
    server = SocketServer(f'socket://{hostname}:{port}?backend={protocol}' + security)
    # Bound to a local so the handler object stays referenced while the server runs.
    _ = ServerSignalHandler(server)

    try:
        info('Running Server')
        server.run_server()
    except (KeyboardInterrupt, Exception):
        # Release the server before propagating; a bare `raise` keeps the original traceback.
        server.close()
        raise
if __name__ == '__main__':
    # Manual run: serve on localhost:37382, forwarding to the 'file:server_test.json' backend.
    start_track_server('file:server_test.json', 'localhost', 37382)