Source code for track.structure

""" hold basic data type classes that all backends need to implement """

from typing import List, Dict, Optional, Set
from dataclasses import dataclass, field
from enum import Enum
from track.versioning import compute_hash


[docs]class Status(Enum): CreatedGroup = 0 # was created nothing is planed for it at the moment RunningGroup = 100 # is running Running = 101 ErrorGroup = 200 # Not Running; because: Interrupted = 201 # -> SIGTERM Exception = 202 # -> Exception was raised Broken = 203 # -> multiple error occured, will not be retried FinishedGroup = 300 # Not Running because: Suspended = 301 # -> was suspended by the user Completed = 302 # -> has finished running def __str__(self): return self.name.lower()
_STATUS_INT = set(map(lambda x: x.value, Status.__members__.values())) _STATUS_STR = set(map(lambda x: x.name, Status.__members__.values()))
[docs]class CustomStatus: def __init__(self, name, value): self._name = name self._value = value @property def name(self): return self._name @property def value(self): return self._value def __repr__(self): return f'CStatus<{self.name}>' def __str__(self): return self.name.lower() def __eq__(self, other): if isinstance(other, dict): return self.value == other.get('value') and self.name == other.get('name') if isinstance(other, str): return self.name == other return self.value == other.value and self.name == other.name
[docs]def status(name=None, value=None): if name is not None: if name in _STATUS_STR: return Status.__members__[name] return CustomStatus(name, value) if value is not None: if value in _STATUS_INT: return Status(value) return CustomStatus(name, value)
[docs]@dataclass class Trial: """ A single training run """ @property def uid(self) -> str: return f'{self.hash}_{self.revision}' @uid.setter def uid(self, value): h, r = value.rsplit('_', maxsplit=1) self._hash = h self.revision = r @property def hash(self): if self._hash is None: self._hash = self.compute_hash() return self._hash @hash.setter def hash(self, v): self._hash = v
[docs] def compute_hash(self) -> str: return compute_hash(self.name, self.version, **self.parameters)
_hash: str = None revision: int = 0 # if uid is a duplicate rev += 1 name: Optional[str] = None description: Optional[str] = None tags: Dict[str, any] = field(default_factory=dict) version: Optional[str] = None group_id: Optional[int] = None project_id: Optional[int] = None # Arguments and parameters parameters: Dict[str, any] = field(default_factory=dict) # Meta data about the trial metadata: Dict[str, any] = field(default_factory=dict) # Training metrics metrics: Dict[str, any] = field(default_factory=dict) # Timers are saved here chronos: Dict[str, any] = field(default_factory=dict) status: Optional[Status] = Status.CreatedGroup # List of errors that occurred during the training errors: List[str] = field(default_factory=list) def __hash__(self): return hash(self.uid) def __eq__(self, other): return hash(other) == hash(self.uid)
[docs]@dataclass class TrialGroup: """ Namespace / Set of trials """
[docs] def compute_uid(self) -> str: assert self.project_id is not None, f'Trial Group (name: {self.name}) needs to be associated with a project' assert self.name is not None, f'Trial Group for (project: {self.project_id}) has no name!' return compute_hash(self.name, self.project_id)
@property def uid(self): if self._uid is None: self._uid = self.compute_uid() return self._uid _uid: str = None name: Optional[str] = None description: Optional[str] = None metadata: Dict[str, any] = field(default_factory=dict) trials: Set[Trial] = field(default_factory=set) project_id: Optional[int] = None def __hash__(self): return hash(self.uid) def __eq__(self, other): return hash(other) == hash(self.uid)
[docs]@dataclass class Project: """ Set of Trial Groups & trials If projects define tags than all children inherit those tags. children cannot override the tag of a parent """
[docs] def compute_uid(self) -> str: assert self.name is not None, f'Project need a name!' return self.name
@property def uid(self): if self._uid is None: self._uid = self.compute_uid() return self._uid _uid: str = None name: Optional[str] = None description: Optional[str] = None metadata: Dict[str, any] = field(default_factory=dict) groups: Set[TrialGroup] = field(default_factory=set) trials: Set[Trial] = field(default_factory=set) def __hash__(self): return hash(self.uid) def __eq__(self, other): return hash(other) == hash(self.uid)
_current_project = None _current_trial = None
[docs]def get_current_project(): return _current_project
[docs]def get_current_trial(): return _current_trial
[docs]def set_current_project(project): global _current_project _current_project = project return _current_project
[docs]def set_current_trial(trial): global _current_trial _current_trial = trial return _current_trial
if __name__ == '__main__': print(status(name='Running')) print(status(value=101)) print(status(name='test', value=23))