Source code for offlax.utils

"""Code for utility functions used in Offlax
"""
from typing import Callable

import gym
import numpy as np

from offlax.replay_buffer import ReplayBuffer


def generate_offlax_dataset(
    env: gym.Env, agent: Callable, steps: int, path: str
) -> None:
    """Generates an offlax dataset for the environment and agent

    Rolls out ``agent`` in ``env`` for ``steps`` transitions, recording the
    observation the agent acted on, the action it chose, and the reward and
    done flag returned by the environment. The recorded sequences are
    concatenated with ``np.hstack`` and handed to ``ReplayBuffer``, which is
    then dumped to ``path``.

    Args:
        env (gym.Env): An Gym API compatible environment. The classic API is
            used: ``reset()`` returns the observation and ``step(action)``
            returns ``(obs, reward, done, info)``.
        agent (Callable): a function that returns an action given the state
            of the environment
        steps (int): number of steps (transitions) to record
        path (str): destination forwarded to ``ReplayBuffer.dump``
    """
    observations = []
    actions = []
    rewards = []
    dones = []

    obs = env.reset()
    for _ in range(steps):
        action = agent(obs)
        next_obs, reward, done, _info = env.step(action)

        # Store each quantity in its own buffer; the observation recorded is
        # the one the agent actually acted on.
        observations.append(obs)
        actions.append(action)
        rewards.append(reward)
        dones.append(done)

        # A finished episode must be reset before stepping again, otherwise
        # the remaining transitions come from a terminated environment.
        obs = env.reset() if done else next_obs

    def _stack(items):
        # Stack once at the end instead of re-allocating every step. With no
        # recorded steps, keep passing None as the buffer placeholder.
        # NOTE(review): np.hstack joins 1-D observations end-to-end; confirm
        # this is the layout ReplayBuffer expects (vs. one row per step).
        return np.hstack(items) if items else None

    ReplayBuffer(
        _stack(observations),
        _stack(actions),
        _stack(rewards),
        _stack(dones),
    ).dump(path)