Source code for track.persistence.storage

import os
import json
from dataclasses import dataclass, field
from typing import Dict, Set
import tempfile
from uuid import UUID

from track.utils.log import error, warning, debug
from track.structure import Project, Trial, TrialGroup
from track.serialization import from_json, to_json
from track.aggregators.aggregator import StatAggregator


# Paths for which load_database() has already emitted its
# "Local Storage was not found" warning; used to warn once per path.
_print_warning_once = set()


@dataclass
class LocalStorage:
    """In-memory object store that is persisted as a single JSON file.

    Every object is kept in ``_objects`` keyed by its ``uid``; the remaining
    fields are indexes (uid sets per kind and name -> uid maps).
    ``_old_rev_tags`` records, per trial uid, the ``_update_count`` the local
    copy had the last time :meth:`_update_object` merged new data into it.
    """

    # Main storage
    target_file: str = None
    _objects: Dict[UUID, any] = field(default_factory=dict)
    # Indexes
    _projects: Set[UUID] = field(default_factory=set)
    _groups: Set[UUID] = field(default_factory=set)
    _trials: Set[UUID] = field(default_factory=set)
    _project_names: Dict[str, UUID] = field(default_factory=dict)
    _group_names: Dict[str, UUID] = field(default_factory=dict)
    _trial_names: Dict[str, UUID] = field(default_factory=dict)
    _old_rev_tags: Dict[str, int] = field(default_factory=dict)

    def get_previous_version_tag(self, obj):
        """Return the version recorded for ``obj`` by the last merge, or 0."""
        return self._old_rev_tags.get(obj.uid, 0)

    def get_current_version_tag(self, obj):
        """Return the ``_update_count`` of a Trial; ``None`` for other objects."""
        if isinstance(obj, Trial):
            return obj.metadata.get('_update_count', 0)

        # Only trials carry a version tag; log instead of printing to stdout.
        debug(f'No version tag available for {obj}')
        return None

    @property
    def objects(self) -> Dict[UUID, any]:
        """All stored objects keyed by uid."""
        return self._objects

    # Indexes
    @property
    def projects(self) -> Set[UUID]:
        """Uids of the stored projects."""
        return self._projects

    @property
    def groups(self) -> Set[UUID]:
        """Uids of the stored trial groups."""
        return self._groups

    @property
    def trials(self) -> Set[UUID]:
        """Uids of the stored trials."""
        return self._trials

    @property
    def project_names(self) -> Dict[str, UUID]:
        """Project name -> uid index."""
        return self._project_names

    @property
    def group_names(self) -> Dict[str, UUID]:
        """Group name -> uid index."""
        return self._group_names

    def commit(self, file_name_override=None, **kwargs):
        """Write every top-level project to disk as JSON.

        The data is first written to a temporary file that is then moved over
        the destination, so readers never observe a half-written file.

        Parameters
        ----------
        file_name_override:
            Destination path; defaults to ``self.target_file``.
        **kwargs:
            Accepted for interface compatibility and ignored.

        Returns
        -------
        None. When no destination is known nothing is written.
        """
        if file_name_override is None:
            file_name_override = self.target_file

        if file_name_override is None:
            debug('No output file target')
            return None

        # only save top level projects
        objects = [to_json(self._objects[uid]) for uid in self._projects]

        # The temporary file must live on the same filesystem as the target,
        # otherwise the final move cannot be a single atomic rename.
        target_dir = os.path.dirname(os.path.abspath(file_name_override))
        fd, file_name = tempfile.mkstemp(prefix='track_uncommitted_', dir=target_dir)

        try:
            with os.fdopen(fd, 'w') as output:
                json.dump(objects, output, indent=2)

            # os.replace overwrites an existing target on every platform
            # (os.rename raises on Windows when the target exists).
            os.replace(file_name, file_name_override)
        except Exception:
            # do not leave stale `track_uncommitted_*` files behind
            if os.path.exists(file_name):
                os.remove(file_name)
            raise

    def _insert_object(self, obj):
        """Store ``obj`` and add its uid to the index matching its type.

        NOTE(review): the name indexes are not updated here, unlike in
        ``load_database`` -- confirm whether that is intended.
        """
        self._objects[obj.uid] = obj

        if isinstance(obj, Trial):
            self._trials.add(obj.uid)
        elif isinstance(obj, TrialGroup):
            self._groups.add(obj.uid)
        elif isinstance(obj, Project):
            self._projects.add(obj.uid)

    def _update_object(self, obj, new):
        """Merge the freshly loaded ``new`` into the already stored ``obj``.

        Raises
        ------
        RuntimeError
            If ``obj`` is a Trial whose ``_update_count`` is greater than the
            one found in ``new``.
        """
        if isinstance(obj, Trial):
            # this is for atomic updates
            obj_oversion = obj.metadata.get('_update_count', 0)
            obj_nversion = new.metadata.get('_update_count', 0)
            self._old_rev_tags[obj.uid] = obj_oversion

            # the object has not changed
            if obj_oversion == obj_nversion:
                return

            if obj_oversion > obj_nversion:
                raise RuntimeError(f'Cannot update object with older version {obj_oversion} > {obj_nversion}!')

            obj.status = new.status
            obj.metrics.update(new.metrics)
            obj.parameters.update(new.parameters)

            # Chrono are special they do not get updated if you are the worker
            for name, val in new.chronos.items():
                chrono = obj.chronos.get(name)

                if chrono is None:
                    obj.chronos[name] = val
                elif isinstance(chrono, StatAggregator):
                    # NOTE(review): this stops at the first local
                    # StatAggregator, so remaining chronos are skipped too;
                    # presumably all local chronos are aggregators in that
                    # case -- confirm.
                    break
                else:
                    chrono.update(val)

        elif isinstance(obj, Project):
            obj.trials.update(set(new.trials))

        elif isinstance(obj, TrialGroup):
            obj.trials.update(set(new.trials))

        # Attributes refreshed for every kind of object
        obj.name = new.name
        obj.description = new.description
        obj.metadata = new.metadata

    def reload(self, filename=None):
        """Reload storage and discard current objects.

        ``filename`` defaults to ``self.target_file``. The object map and all
        indexes are replaced by the freshly loaded ones; ``target_file`` and
        ``_old_rev_tags`` are left untouched.
        """
        debug('reloading files')

        if filename is None:
            filename = self.target_file

        new_storage = load_database(filename)
        self._objects = new_storage._objects
        self._projects = new_storage._projects
        self._groups = new_storage._groups
        self._trials = new_storage._trials
        self._project_names = new_storage._project_names
        self._group_names = new_storage._group_names
        self._trial_names = new_storage._trial_names

    def smart_reload(self, filename=None):
        """Updates current objects with new data.

        Objects already known (same uid) are merged in place through
        :meth:`_update_object`; unknown ones are added with
        :meth:`_insert_object`. ``filename`` defaults to ``self.target_file``.
        """
        debug('smart reloading files')

        if filename is None:
            filename = self.target_file

        new_storage = load_database(filename)

        for uid, obj in new_storage.objects.items():
            old_obj = self.objects.get(uid)

            if old_obj is not None:
                self._update_object(old_obj, obj)
            else:
                self._insert_object(obj)
def load_database(json_name):
    """Build a ``LocalStorage`` from the JSON file at ``json_name``.

    * ``json_name`` is ``None``: an empty storage without target file.
    * the file does not exist: a warning is logged (once per path) and an
      empty storage targeting ``json_name`` is returned.
    * otherwise every serialized item is decoded with ``from_json`` and
      registered in the object map and the matching indexes. Trials and
      groups nested in a project are registered by uid as well.

    Raises ``RuntimeError`` when two top-level items share a uid.
    """
    if json_name is None:
        return LocalStorage()

    if not os.path.exists(json_name):
        already_warned = json_name in _print_warning_once
        if not already_warned:
            warning(f'Local Storage was not found at {json_name}')
            _print_warning_once.add(json_name)
        return LocalStorage(target_file=json_name)

    with open(json_name, 'r') as handle:
        serialized = json.load(handle)

    objects_by_uid = {}
    project_ids, group_ids, trial_ids = set(), set(), set()
    names_of_projects, names_of_groups, names_of_trials = {}, {}, {}

    for record in serialized:
        obj = from_json(record)

        if obj.uid in objects_by_uid:
            raise RuntimeError('Should be unreachable!')
        objects_by_uid[obj.uid] = obj

        if isinstance(obj, Project):
            project_ids.add(obj.uid)

            if obj.name in names_of_projects:
                error('Non unique project names are not supported')
            if obj.name is not None:
                names_of_projects[obj.name] = obj.uid

            # children of a project are reachable by uid too
            for child_trial in obj.trials:
                objects_by_uid[child_trial.uid] = child_trial
                trial_ids.add(child_trial.uid)

            for child_group in obj.groups:
                objects_by_uid[child_group.uid] = child_group
                group_ids.add(child_group.uid)

        elif isinstance(obj, Trial):
            trial_ids.add(obj.uid)
            if obj.name is not None:
                names_of_trials[obj.name] = obj.uid

        elif isinstance(obj, TrialGroup):
            group_ids.add(obj.uid)
            if obj.name is not None:
                names_of_groups[obj.name] = obj.uid

    return LocalStorage(
        target_file=json_name,
        _objects=objects_by_uid,
        _projects=project_ids,
        _groups=group_ids,
        _trials=trial_ids,
        _project_names=names_of_projects,
        _group_names=names_of_groups,
        _trial_names=names_of_trials,
    )