Source code for offlax.runner

from typing import Any, Dict, Optional

import ray
from ray import air, tune
from ray.air.integrations.wandb import WandbLoggerCallback

__all__ = ["OfflaxRunner"]


# @ray.remote
class OfflaxRunner:
    """Generic runner that encapsulates search, training and inference.

    Constructing a runner calls ``ray.init`` with any extra positional and
    keyword arguments, so Ray is initialised as a side effect.

    Args:
        algorithm: Offlax algorithm object to run. It must provide
            ``train(...)`` and ``parse_config(config)``, and also
            ``get_search_space()`` when ``search_space`` is not given.
        search_space: Search space handed to Ray Tune as ``param_space``.
            Defaults to ``algorithm.get_search_space()`` if unspecified.
        enable_wandb: When true, a ``WandbLoggerCallback`` is attached to
            tuning runs and the flag is forwarded to ``algorithm.train``.
        wandb_kwargs: Keyword arguments for ``WandbLoggerCallback``.
            ``None`` (the default) means no extra arguments.
        *args: Forwarded to ``ray.init``.
        **kwargs: Forwarded to ``ray.init``.
    """

    def __init__(
        self,
        algorithm: Any,
        search_space=None,
        enable_wandb: bool = True,
        wandb_kwargs: Optional[Dict] = None,
        *args,
        **kwargs
    ):
        ray.init(*args, **kwargs)
        self.algorithm = algorithm
        if search_space is None:
            self.search_space = self.algorithm.get_search_space()
        else:
            self.search_space = search_space
        self.enable_wandb = enable_wandb
        # A fresh dict per instance: a ``{}`` default in the signature would
        # be one object shared (and mutable) across every runner.
        self.wandb_kwargs = {} if wandb_kwargs is None else wandb_kwargs
        self.callbacks = self._get_callbacks()

    def _get_callbacks(self):
        """Return the run callbacks, or ``None`` when W&B logging is off."""
        if not self.enable_wandb:
            return None
        return [WandbLoggerCallback(**self.wandb_kwargs)]

    def search(self, environment: str):
        """Run a Ray Tune search over ``self.search_space``.

        Args:
            environment: Environment name bound to ``train`` for every trial.

        Returns:
            Whatever ``tune.Tuner.fit()`` returns.
        """
        trainable = tune.with_parameters(self.train, environment=environment)
        tuner = tune.Tuner(
            trainable,
            param_space=self.search_space,
            run_config=air.RunConfig(callbacks=self.callbacks),
        )
        return tuner.fit()

    def train(self, config=None, environment: str = "maze2d-open-v0"):
        """Train the algorithm, either directly or as a Tune trial.

        Args:
            config: Trial configuration. When given, it is turned into an
                algorithm via ``self.algorithm.parse_config`` and trained
                with ``rtune=tune``; when ``None``, ``self.algorithm`` is
                trained as-is.
            environment: Environment name forwarded to ``algorithm.train``.
        """
        if config is None:
            # Stand-alone training: no Tune handle is passed along.
            self.algorithm.train(
                enable_wandb=self.enable_wandb, environment=environment
            )
            return

        algorithm = self.algorithm.parse_config(config)
        algorithm.train(
            rtune=tune, enable_wandb=self.enable_wandb, environment=environment
        )

    def test(self):
        """Placeholder for inference; currently does nothing."""
        pass