import functools
import logging
import time
from threading import RLock
from typing import Any, Callable

from filelock import FileLock, logger as file_lock_logger

from track.configuration import options
from track.utils import ItemNotFound
from track.utils.signal import SignalHandler
from track.utils.log import error, warning, debug
from track.structure import Project, Trial, TrialGroup
from track.persistence.protocol import Protocol
from track.persistence.storage import load_database, LocalStorage
from track.persistence.utils import parse_uri
from track.containers.types import float32
from track.aggregators.aggregator import Aggregator
from track.aggregators.aggregator import RingAggregator
from track.aggregators.aggregator import StatAggregator
from track.aggregators.aggregator import ValueAggregator
from track.aggregators.aggregator import TimeSeriesAggregator
# Module-level aggregator factories built with `.lazy(...)`.
# `_make_container` calls `ts_aggregator()` with no argument to obtain a fresh container,
# so each of these is presumably a zero-argument callable -- NOTE(review): confirm against
# `Aggregator.lazy`; the meaning of the `10`, `float32` and `1` arguments is not visible here.
value_aggregator = ValueAggregator.lazy()
ring_aggregator = RingAggregator.lazy(10, float32)
stat_aggregator = StatAggregator.lazy(1)
ts_aggregator = TimeSeriesAggregator.lazy()
# Only let filelock log records of level ERROR and above through.
file_lock_logger().setLevel(logging.ERROR)
def _make_container(step, aggregator):
    """Create the container that will receive the values of a new metric.

    Parameters
    ----------
    step:
        step the metric is logged at; any value other than ``None`` selects a plain ``dict``

    aggregator:
        zero-argument callable building the container used when ``step`` is ``None``;
        when it is ``None`` as well, the module-level ``ts_aggregator`` is used

    Returns
    -------
    an empty ``dict`` for stepped metrics, otherwise whatever the selected factory returns
    """
    if step is not None:
        return {}

    # favor ts aggregator because it has an option to cut the TS for printing purposes
    factory = ts_aggregator if aggregator is None else aggregator
    return factory()
class _NoLockLock:
    """No-op stand-in for a file lock, returned by `make_lock` when `eager` is false.

    It implements the part of the lock interface used in this module: it can be used as a
    context manager, `acquire()` returns an object usable in a ``with`` statement and
    `release()` is accepted. None of these operations does any locking.
    """

    def __enter__(self):
        pass

    def __exit__(self, exc_type, exc_val, exc_tb):
        pass

    def acquire(self, *args, **kwargs):
        """Accept and ignore any argument; return ``self`` so ``with lock.acquire():`` works."""
        return self

    def release(self, *args, **kwargs):
        """Accept and ignore any argument; there is nothing to release.

        `MultiLock.__exit__` calls ``lock.release()`` on whichever lock `make_lock` returned,
        so the no-op lock has to provide the method as well.
        """
        pass
class MultiLock:
    """Context manager acquiring both the thread lock and the file lock of ``obj``.

    Parameters
    ----------
    obj:
        object exposing ``thread_lock`` (acquired with a ``timeout`` keyword), ``lock``
        (``acquire()`` / ``release()``) and an integer ``lock_guard_depth`` counter

    timeout: float
        maximum number of seconds to wait for the thread lock

    Raises
    ------
    TimeoutError
        when the thread lock could not be acquired within ``timeout`` seconds
    """

    def __init__(self, obj, timeout=30):
        self.obj = obj
        self.timeout = timeout

    def __enter__(self):
        # `acquire(timeout=...)` reports a timeout through its return value; carrying on
        # without the lock would only fail later when releasing a lock we do not own
        if not self.obj.thread_lock.acquire(timeout=self.timeout):
            raise TimeoutError(f'Could not acquire the thread lock within {self.timeout}s')

        try:
            self.obj.lock.acquire()
        except BaseException:
            # `__exit__` is not called when `__enter__` raises: give the thread lock back here
            self.obj.thread_lock.release()
            raise

        self.obj.lock_guard_depth += 1
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # undo `__enter__` in reverse order so the depth counter and the file lock are only
        # touched while the thread lock is still held
        try:
            self.obj.lock_guard_depth -= 1
            self.obj.lock.release()
        finally:
            self.obj.thread_lock.release()
def make_lock(name, eager):
    """Build the lock protecting the file `name`.

    Parameters
    ----------
    name: str
        path of the lock file

    eager: bool
        when true a `FileLock` is returned, its timeout is read from the
        `log.backend.lock_timeout` option (30 when the option is not set);
        otherwise a `_NoLockLock` is returned
    """
    if not eager:
        return _NoLockLock()

    lock_timeout = options('log.backend.lock_timeout', 30)
    return FileLock(name, timeout=lock_timeout)
class ConcurrentWrite(Exception):
    """Exception carrying a message about a concurrent write.

    NOTE(review): nothing in this part of the file raises it; check callers before relying on it.
    """

    def __init__(self, msg):
        super().__init__(msg)
def lock_guard(readonly, atomic=False):
    """Protect a function call with a lock. Reload the database before the action and save it afterwards.

    Parameters
    ----------
    readonly: bool
        when true the storage is not committed after the wrapped call

    atomic: bool
        accepted by the signature but not used by the current implementation

    Returns
    -------
    a decorator for methods of an object exposing `path`, `eager`, `storage`, `commit()`,
    `lock_guard_depth` and the locks `MultiLock` needs. The decorated method keeps its
    name and docstring.
    """

    def lock_guard_decorator(fun):
        # keep `__name__`, `__doc__`, ... of the wrapped method for introspection and docs
        @functools.wraps(fun)
        def _lock_guard(self, *args, **kwargs):
            with MultiLock(self):
                # only reload database if path is not none and the lock is not already owned
                if self.path and self.eager and self.lock_guard_depth == 1:
                    self.storage = load_database(self.path)

                val = fun(self, *args, **kwargs)

                # persist the changes while the locks are still held
                if self.eager and not readonly:
                    self.commit()

                return val

        return _lock_guard

    return lock_guard_decorator
# Ready-made decorators built from `lock_guard`.
# `lock_write` commits the storage after the call when the protocol is eager.
lock_write = lock_guard(readonly=False)
# Same as `lock_write` with `atomic=True`; `lock_guard` does not read that flag at the moment.
lock_atomic_write = lock_guard(readonly=False, atomic=True)
# `lock_read` never commits after the call.
lock_read = lock_guard(readonly=True)
class LockFileRemover(SignalHandler):
    """Signal handler deleting a lock file on SIGTERM, SIGINT and at interpreter exit.

    Parameters
    ----------
    filename: str
        path of the lock file to delete
    """

    def __init__(self, filename):
        super(LockFileRemover, self).__init__()
        self.file_name = filename

    def remove(self):
        """Delete the lock file; a file that is already gone is not an error."""
        import os

        # try-and-handle instead of exists-then-remove: another process can delete the file
        # between the two calls, and this runs inside signal / exit handlers
        try:
            os.remove(self.file_name)
        except FileNotFoundError:
            pass

    def sigterm(self, signum, frame):
        self.remove()

    def sigint(self, signum, frame):
        self.remove()

    def atexit(self):
        self.remove()
class FileProtocol(Protocol):
    """Local File storage to manage experiments

    Parameters
    ----------
    uri: str
        resource to use to store the experiment `file://my_file.json`

    strict: bool
        forces the storage to be correct.
        if we use the file protocol as an in-memory storage we might get some inconsistencies
        we can use this flag to ignore them

    eager: bool
        eagerly update the underlying files. This is necessary if multiple processes are reading from the file
    """

    def __init__(self, uri, strict=True, eager=True):
        uri = parse_uri(uri)

        # file:test.json
        path = uri.get('path')
        if not path:
            # file://test.json
            path = uri.get('address')

        self.path = path
        self.storage: LocalStorage = load_database(path)
        # start time of the running chronos, keyed by chrono name
        self.chronos = {}
        self.strict = strict
        self.eager = eager
        self.signal_handler = LockFileRemover(f'{path}.lock')
        self.lock = make_lock(f'{path}.lock', eager)
        # nesting level of the `lock_guard` protected calls currently running
        self.lock_guard_depth = 0
        self.thread_lock = RLock()

    def _inc_trial(self, trial):
        """Bump the `_update_count` of `trial` and record the time of the change."""
        trial.metadata['_update_count'] = trial.metadata.get('_update_count', 0) + 1
        trial.metadata['_last_change'] = time.time()

    @lock_write
    def log_trial_start(self, trial):
        """Attach a fresh `runtime` chrono to `trial` and its stored copy, and record the start time."""
        ntrial = self.storage.objects.get(trial.uid)

        acc = ValueAggregator()
        trial.chronos['runtime'] = acc
        self.chronos['runtime'] = time.time()

        ntrial.chronos['runtime'] = acc
        self._inc_trial(ntrial)
        return trial

    @lock_write
    def log_trial_finish(self, trial, exc_type, exc_val, exc_tb):
        """Append the time elapsed since `log_trial_start` to the `runtime` chrono."""
        ntrial = self.storage.objects.get(trial.uid)

        start_time = self.chronos['runtime']
        acc = trial.chronos['runtime']
        acc.append(time.time() - start_time)

        ntrial.chronos['runtime'] = acc
        self._inc_trial(ntrial)

    @lock_write
    def log_trial_chrono_start(self, trial, name: str, aggregator: Callable[[], Aggregator] = StatAggregator.lazy(1),
                               start_callback=None,
                               end_callback=None):
        """Start the chrono `name`, creating its aggregator with `aggregator()` on first use.

        `start_callback` and `end_callback` are accepted but not used here.
        """
        ntrial = self.storage.objects.get(trial.uid)

        agg = trial.chronos.get(name)
        if agg is None:
            agg = aggregator()
            trial.chronos[name] = agg

        self.chronos[name] = time.time()
        ntrial.chronos[name] = agg
        self._inc_trial(ntrial)

    @lock_write
    def log_trial_chrono_finish(self, trial, name, exc_type, exc_val, exc_tb):
        """Append the time elapsed since `log_trial_chrono_start` to the chrono `name`."""
        ntrial = self.storage.objects.get(trial.uid)

        start_time = self.chronos[name]
        acc = trial.chronos[name]
        acc.append(time.time() - start_time)

        ntrial.chronos[name] = acc
        self._inc_trial(ntrial)

    @lock_write
    def log_trial_metrics(self, trial: Trial, step: Any = None, aggregator: Callable[[], Aggregator] = None, **kwargs):
        """Record one value per keyword argument in the metrics of `trial` and of its stored copy.

        Parameters
        ----------
        trial: Trial
            trial receiving the values

        step:
            step the values belong to. `dict` containers are indexed by it, any other
            container receives `(step, value)` pairs; `None` appends the bare value

        aggregator:
            factory of the container created for a metric logged without `step` for the first time

        kwargs:
            metric name to value
        """
        ntrial = self.storage.objects.get(trial.uid)

        for name, value in kwargs.items():
            container = trial.metrics.get(name)
            if container is None:
                container = _make_container(step, aggregator)
                trial.metrics[name] = container

            # compare against None, not truthiness: step 0 is a valid step
            if step is None:
                container.append(value)
            elif isinstance(container, dict):
                container[step] = value
            else:
                container.append((step, value))

        ntrial.metrics.update(trial.metrics)
        self._inc_trial(ntrial)

    @lock_write
    def log_trial_arguments(self, trial, **kwargs):
        """Merge `kwargs` into the parameters of the stored copy of `trial`."""
        trial = self.storage.objects.get(trial.uid)
        trial.parameters.update(kwargs)
        self._inc_trial(trial)

    # Object Creation
    @lock_read
    def get_project(self, project: Project):
        """Return the stored object registered under `project.uid`, `None` when there is none."""
        debug(f'look for (project: {project.name})')
        return self.storage.objects.get(project.uid)

    @lock_write
    def new_project(self, project: Project):
        """Register `project`; when its uid is already taken, log an error and return the stored object."""
        debug(f'create new (project: {project.name})')

        if project.uid in self.storage.objects:
            error(f'Cannot insert project; (uid: {project.uid}) already exists!')
            return self.get_project(project)

        self.storage.objects[project.uid] = project
        self.storage.project_names[project.name] = project.uid
        self.storage.projects.add(project.uid)
        return project

    @lock_read
    def get_trial_group(self, group: TrialGroup):
        """Return the stored object registered under `group.uid`, `None` when there is none."""
        return self.storage.objects.get(group.uid)

    @lock_write
    def new_trial_group(self, group: TrialGroup):
        """Register `group` and attach it to its project.

        Returns `None` (after logging an error) when the uid is already taken, `group` otherwise.

        Raises
        ------
        AssertionError
            in strict mode, when the project of the group is not stored
        """
        debug(f'create new (group: {group.name})')

        if group.uid in self.storage.objects:
            error(f'Cannot insert group; (uid: {group.uid}) already exists!')
            return

        project = self.storage.objects.get(group.project_id)
        if project is None:
            if self.strict:
                # raised explicitly so the check also runs under `python -O`
                raise AssertionError('Cannot create a group without an associated project')
            # non strict storage tolerates the inconsistency instead of failing on `None`
            warning('Orphan group')
        else:
            project.groups.add(group)

        self.storage.objects[group.uid] = group
        self.storage.groups.add(group.uid)
        self.storage.group_names[group.name] = group.uid
        return group

    @lock_read
    def get_trial(self, trial: Trial):
        """Return every stored object whose key starts with `trial.hash`.

        Returns `None` when `trial.uid` itself is not stored.
        """
        if trial.uid not in self.storage.objects:
            return None

        trial_hash = trial.hash
        return [obj for key, obj in self.storage.objects.items() if key.startswith(trial_hash)]

    @lock_write
    def new_trial(self, trial: Trial, auto_increment=False):
        """Register `trial` and attach it to its project and group.

        Parameters
        ----------
        trial: Trial
            trial to insert

        auto_increment: bool
            when the uid is already taken, insert the trial under the next revision number
            instead of giving up

        Returns
        -------
        the inserted trial, or `None` when the uid is taken and `auto_increment` is false
        """
        debug(f'create new (trial: {trial.uid})')

        if trial.uid in self.storage.objects:
            if not auto_increment:
                debug('Found already existing trial')
                return None

            trials = self.get_trial(trial)
            max_rev = max([0, *(t.revision for t in trials)])

            warning(f'Trial was already completed. Increasing revision number (rev={max_rev + 1})')
            trial.revision = max_rev + 1
            # reset the cached hash; presumably recomputed from the new revision -- TODO confirm
            trial._hash = None

        self.storage.objects[trial.uid] = trial
        self.storage.trials.add(trial.uid)

        if trial.project_id is not None:
            project = self.storage.objects.get(trial.project_id)
            # in strict mode a missing project fails right here (attribute access on `None`)
            if project is not None or self.strict:
                project.trials.add(trial)
            else:
                warning('Orphan trial')

        if trial.group_id is not None:
            group = self.storage.objects.get(trial.group_id)
            # same as above: strict mode fails on a missing group
            if group is not None or self.strict:
                group.trials.add(trial.uid)

        trial.metadata['_update_count'] = 0
        return trial

    @lock_write
    def add_project_trial(self, project, trial):
        """Attach `trial` to `project`.

        Raises
        ------
        AssertionError
            when `project` is `None`
        """
        if project is None:
            # raised explicitly so the check also runs under `python -O`
            raise AssertionError('Project cant be None')

        trial.project_id = project.uid
        project.trials.add(trial)

    @lock_write
    def add_group_trial(self, group, trial):
        """Attach `trial` to `group`; a `None` group is silently ignored when not strict."""
        if group is None and not self.strict:
            return

        trial.group_id = group.uid
        group.trials.add(trial.uid)

    def commit(self, file_name_override=None, **kwargs):
        """Write the storage to disk while holding the file lock; only warn when no path is set."""
        if self.path:
            with self.lock.acquire():
                self.storage.commit(file_name_override=file_name_override, **kwargs)
        else:
            warning('Path undefined!')

    @lock_read
    def _fetch_objects(self, objects, query, strict=False):
        """Return the stored objects listed in `objects` that match `query`.

        Parameters
        ----------
        objects:
            iterable of object ids to look up in the storage

        query:
            query understood by `execute_query`

        strict: bool
            raise a `RuntimeError` on an id without stored object instead of logging
            a warning and skipping it
        """
        matching_objects = []

        for obj_id in objects:
            obj = self.storage.objects.get(obj_id)

            if obj is None:
                err = f'stale trial (id: {obj_id}) something is wrong'
                if strict:
                    raise RuntimeError(err)
                warning(err)
                continue

            if execute_query(obj, query):
                matching_objects.append(obj)

        return matching_objects

    @lock_write
    def fetch_and_update_trial(self, query, attr, *args, **kwargs):
        """Call the method named `attr` on the first trial matching `query` and return that trial.

        Raises
        ------
        ItemNotFound
            when no trial matches the query
        """
        trials = self.fetch_trials(query)

        if not trials:
            raise ItemNotFound(f'Expected one or more trial got {len(trials)} trials with query `{query}`')

        fun = getattr(self, attr)
        selected = trials[0]
        fun(selected, *args, **kwargs)
        return selected

    @lock_atomic_write
    def set_trial_status(self, trial, status, error=None):
        """Set the status of the stored copy of `trial`; when given, `str(error)` is added to its errors."""
        trial = self.storage.objects.get(trial.uid)
        trial.status = status

        if error is not None:
            trial.errors.append(str(error))

        self._inc_trial(trial)

    @lock_write
    def fetch_and_update_group(self, query, attr, *args, **kwargs):
        """Call the method named `attr` on the first group matching `query` and return that group.

        Raises
        ------
        ItemNotFound
            when no group matches the query
        """
        groups = self.fetch_groups(query)

        if not groups:
            raise ItemNotFound(f'Expected one or more group got {len(groups)} groups with query `{query}`')

        fun = getattr(self, attr)
        selected = groups[0]
        fun(selected, *args, **kwargs)
        return selected

    @lock_read
    def fetch_trials(self, query=None):
        """Return the stored trials matching `query` (all of them when it is `None`)."""
        return self._fetch_objects(self.storage.trials, query)

    @lock_read
    def fetch_groups(self, query=None):
        """Return the stored groups matching `query` (all of them when it is `None`)."""
        return self._fetch_objects(self.storage.groups, query)

    @lock_read
    def fetch_projects(self, query=None):
        """Return the stored projects matching `query` (all of them when it is `None`)."""
        return self._fetch_objects(self.storage.projects, query)
def _get_attribute(obj, attrs):
    """Resolve a path on `obj`.

    The first element of `attrs` is read with `getattr`, every following element is looked
    up with `.get(key)` on the value found so far.

    Returns
    -------
    the value at the end of the path, or `None` as soon as one of the `.get` lookups yields `None`
    """
    node = getattr(obj, attrs[0])

    for key in attrs[1:]:
        node = node.get(key)
        if node is None:
            # nothing below a missing key
            break

    return node
def execute_query(obj, query):
    """Check if the object `obj` matches the query.

    The query is a dictionary specifying constraint on each of the object attributes.
    Keys are dotted paths (`attribute.key.key`), values are either the expected value or a
    single-entry dictionary `{function_name: argument}` naming an entry of `_query_fun`.

    Raises
    ------
    RuntimeError
        when `obj` lacks the attribute, the function is unknown or a condition dictionary
        does not hold exactly one entry
    """
    if query is None:
        return True

    # a query can be a dict or a list of conditions
    # allowing a list enable users to make sure the conditions are executed in a specific order
    # this can be used to speed up query. You can put the most strict condition first to reduce the number of checks
    # we have to do to select a query
    conditions = query.items() if isinstance(query, dict) else list(query)

    is_selected = True
    for dotted_path, condition in conditions:
        attrs = dotted_path.split('.')

        if not hasattr(obj, attrs[0]):
            raise RuntimeError(f'(obj: {type(obj)}) has no (attribute: {attrs[0]})')

        if not isinstance(condition, dict):
            # this is a simple value
            is_selected &= _get_attribute(obj, attrs) == condition
        elif len(condition) != 1:
            raise RuntimeError(f'(query: {query}) was not understood')
        else:
            # This is a complex query that needs to be further processed
            (fun_name, args), = condition.items()
            fun = _query_fun.get(fun_name)
            if fun is None:
                raise RuntimeError(f'(function: {fun_name}) is not understood')
            is_selected &= fun(obj, attrs, args)

        # shortcut
        if not is_selected:
            return False

    return is_selected
def query_in(obj, attrs, choices):
    """`$in` operator: tell whether the value found at `attrs` is one of `choices`.

    When the first choice is a string the value is converted with `str` before the
    membership test. An empty `choices` matches nothing.
    """
    if len(choices) == 0:
        # nothing belongs to an empty collection; also avoids indexing `choices[0]` below
        return False

    value = _get_attribute(obj, attrs)
    if isinstance(choices[0], str):
        value = str(value)

    return value in choices
def query_ne(obj, attrs, val):
    """`$ne` operator: tell whether the value found at `attrs` differs from `val`."""
    return _get_attribute(obj, attrs) != val
def query_lte(obj, attrs, val):
    """`$lte` operator: tell whether the value found at `attrs` is lower than or equal to `val`.

    A missing value (`_get_attribute` returned `None`) does not match.
    """
    value = _get_attribute(obj, attrs)
    # `None <= val` raises TypeError; an object without the value simply does not match
    return value is not None and value <= val
def query_gt(obj, attrs, val):
    """`$gt` operator: tell whether the value found at `attrs` is strictly greater than `val`.

    A missing value (`_get_attribute` returned `None`) does not match.
    """
    value = _get_attribute(obj, attrs)
    # `None > val` raises TypeError; an object without the value simply does not match
    return value is not None and value > val
# Query operators understood by `execute_query`: operator name -> function called as
# `fun(obj, attrs, args)`. Any other operator name makes `execute_query` raise a RuntimeError.
_query_fun = {
'$in': query_in,
'$ne': query_ne,
'$lte': query_lte,
'$gt': query_gt
}