Source code for track.persistence.protocol

from typing import Any, Callable, Optional, List

from track.structure import Project, TrialGroup, Trial
from track.aggregators.aggregator import Aggregator
from track.aggregators.aggregator import StatAggregator
from track.aggregators.aggregator import ValueAggregator

# Module-level aggregator factory, created once at import time; it is the
# default ``aggregator`` argument of ``Protocol.log_trial_metadata``.
value_aggregator: Callable[[], Aggregator] = ValueAggregator.lazy()


class Protocol:
    """Interface of a track persistence backend.

    Every operation raises ``NotImplementedError`` here; subclasses override
    the operations they support.
    """

    def log_trial_start(self, trial):
        """Send the trial start signal.

        Parameters
        ----------
        trial: Trial
            reference to the trial being started
        """
        raise NotImplementedError()

    def log_trial_finish(self, trial, exc_type, exc_val, exc_tb):
        """Send the trial end signal.

        Parameters
        ----------
        trial: Trial
            reference to the trial that finished
        exc_type:
            type of the exception that ended the trial, if any
        exc_val:
            exception value, if any
        exc_tb:
            exception traceback, if any
        """
        raise NotImplementedError()

    def log_trial_chrono_start(self, trial, name: str,
                               aggregator: Callable[[], Aggregator] = StatAggregator.lazy(1),
                               start_callback=None,
                               end_callback=None):
        """Send the start signal for an event.

        Parameters
        ----------
        trial: Trial
            trial sending the event
        name: str
            name of the event
        aggregator: Callable[[], Aggregator]
            factory of the container used to accumulate elapsed time
            (the default factory is created once, when the class is defined)
        start_callback: Callable
            function called at start time
        end_callback: Callable
            function called at the end
        """
        raise NotImplementedError()

    def log_trial_chrono_finish(self, trial, name, exc_type, exc_val, exc_tb):
        """Send the end signal for an event.

        Parameters
        ----------
        trial: Trial
            trial sending the event
        name: str
            name of the event
        exc_type:
            exception type, if any
        exc_val:
            exception value, if any
        exc_tb:
            exception traceback, if any
        """
        raise NotImplementedError()

    def log_trial_arguments(self, trial: Trial, **kwargs):
        """Save the arguments of a trial.

        Parameters
        ----------
        trial: Trial
            trial the arguments belong to
        kwargs:
            key value pairs of arguments
        """
        raise NotImplementedError()

    def log_trial_metadata(self, trial: Trial,
                           aggregator: Callable[[], Aggregator] = value_aggregator,
                           **kwargs):
        """Save metadata for a given trial.

        Parameters
        ----------
        trial: Trial
            trial reference
        aggregator: Callable[[], Aggregator]
            factory of the container used to accumulate the values;
            defaults to the module-level ``value_aggregator``
        kwargs:
            key value pairs of the data to save
        """
        raise NotImplementedError()

    def log_trial_metrics(self, trial: Trial, step: Any = None,
                          aggregator: Optional[Callable[[], Aggregator]] = None,
                          **kwargs):
        """Save metrics for a given trial.

        Parameters
        ----------
        trial: Trial
            trial reference
        step: Any, optional
            step identifier associated with the logged values
        aggregator: Callable[[], Aggregator], optional
            factory of the container used to accumulate the values
        kwargs:
            key value pairs of the data to save
        """
        raise NotImplementedError()

    def set_trial_status(self, trial: Trial, status, error=None):
        """Change the trial status.

        Parameters
        ----------
        trial: Trial
            trial reference
        status:
            new status to update the trial to
        error:
            when changing to a state that represents an error, an optional
            error identification string
        """
        raise NotImplementedError()

    def add_trial_tags(self, trial, **kwargs):
        """Add tags to a trial.

        Parameters
        ----------
        trial: Trial
            trial reference
        kwargs:
            key value pairs of the tags to save
        """
        raise NotImplementedError()

    # Object Creation
    def get_project(self, project: Project) -> Optional[Project]:
        """Fetch a project according to the given definition.

        Parameters
        ----------
        project: Project
            project definition used for the lookup

        Returns
        -------
        the matching project object or None
        """
        raise NotImplementedError()

    def new_project(self, project: Project):
        """Insert a new project.

        Parameters
        ----------
        project: Project
            project definition used for the insert
        """
        raise NotImplementedError()

    def get_trial_group(self, group: TrialGroup) -> Optional[TrialGroup]:
        """Fetch a group according to the given definition.

        Parameters
        ----------
        group: TrialGroup
            group definition used for the lookup

        Returns
        -------
        the matching group or None
        """
        raise NotImplementedError()

    def new_trial_group(self, group: TrialGroup):
        """Create a new group.

        Parameters
        ----------
        group: TrialGroup
            group definition used for the insert
        """
        raise NotImplementedError()

    def add_project_trial(self, project: Project, trial: Trial):
        """Add a trial to a project."""
        raise NotImplementedError()

    def add_group_trial(self, group: TrialGroup, trial: Trial):
        """Add a trial to a group."""
        raise NotImplementedError()

    def commit(self, **kwargs):
        """Force the pending changes to be persisted.

        Parameters
        ----------
        kwargs:
            implementation-specific options
        """
        raise NotImplementedError()

    def get_trial(self, trial: Trial) -> List[Trial]:
        """Fetch trials according to the given definition.

        Parameters
        ----------
        trial: Trial
            trial definition used for the lookup

        Returns
        -------
        the list of matching trials
        """
        raise NotImplementedError()

    def new_trial(self, trial: Trial, auto_increment=False):
        """Insert a new trial.

        Parameters
        ----------
        trial: Trial
            trial definition used for the insert
        auto_increment: bool
            if the trial already exists, increment its revision number

        Returns
        -------
        None if the trial already exists and ``auto_increment`` is False
        """
        raise NotImplementedError()

    def fetch_and_update_trial(self, query, attr, *args, **kwargs):
        """Fetch and update a single trial.

        Parameters
        ----------
        query: Dict
            dictionary used to fetch the trial
        attr: str
            name of the update function to call on the selected trial
        *args:
            additional positional arguments for the ``attr`` function
        **kwargs:
            additional keyword arguments for the ``attr`` function

        Returns
        -------
        the modified trial
        """
        raise NotImplementedError()

    def fetch_and_update_group(self, query, attr, *args, **kwargs):
        """Fetch and update a single group.

        Parameters
        ----------
        query: Dict
            dictionary used to fetch the group
        attr: str
            name of the update function to call on the selected group
        *args:
            additional positional arguments for the ``attr`` function
        **kwargs:
            additional keyword arguments for the ``attr`` function

        Returns
        -------
        the modified group
        """
        raise NotImplementedError()

    def fetch_trials(self, query) -> List[Trial]:
        """Fetch trials according to the given query."""
        raise NotImplementedError()

    def fetch_groups(self, query):
        """Fetch groups according to the given query."""
        raise NotImplementedError()

    def fetch_projects(self, query):
        """Fetch projects according to the given query."""
        raise NotImplementedError()