Source code for offlax.models

from abc import ABC, abstractmethod
from copy import deepcopy
from typing import List, Tuple, Union

import jax
from flax import linen as nn
from flax.core.scope import VariableDict
from jax import numpy as jnp
from omegaconf import OmegaConf
from tensorflow_probability.substrates import jax as tfp
from ray import tune

tfd = tfp.distributions


class Policy(nn.Module):
    """Generic implementation of a Policy class.

    Subclasses are Flax modules that additionally expose helpers to query the
    network output for given weights (``get_action``) and to describe their
    hyper-parameters as an OmegaConf config (``get_config_dict``).
    """

    # NOTE(review): Policy does not derive from abc.ABC, so unless nn.Module's
    # metaclass is ABCMeta these @abstractmethod markers are presumably not
    # enforced at instantiation; the NotImplementedError bodies are the guard.

    @abstractmethod
    def get_action(self, variables: VariableDict, state: jnp.ndarray, *args, **kwargs):
        """Return the model output for ``state`` using the weights in ``variables``.

        Args:
            variables (VariableDict): weights of the module.
            state (jnp.ndarray): state of the environment.
            *args, **kwargs: subclass-specific options.

        Raises:
            NotImplementedError: always, subclasses must override this.
        """
        raise NotImplementedError

    @abstractmethod
    def get_config_dict(self, name: str = "") -> OmegaConf:
        """Return the hyper-parameters of the module as a config.

        Args:
            name (str, optional): optional namespace for the keys; subclasses
                define how it is used. Defaults to "".

        Raises:
            NotImplementedError: always, subclasses must override this.
        """
        raise NotImplementedError

    def save_config_dict(self, *args, **kwargs) -> None:
        """Serialize the config returned by ``get_config_dict()`` to a file.

        Args:
            *args, **kwargs: forwarded to ``OmegaConf.save`` after the config,
                i.e. the destination ``f`` and optionally ``resolve``.
        """
        config = self.get_config_dict()
        # OmegaConf containers treat attribute access as a key lookup, so
        # ``config.save(...)`` would look up a (missing) key named "save"
        # instead of calling a method; the static OmegaConf.save API is the
        # supported way to write a config to disk.
        OmegaConf.save(config, *args, **kwargs)
class ActorContinuous(Policy):
    """Actor for a continuous action space.

    A three-hidden-layer ReLU MLP with two linear heads producing the mean and
    the log standard deviation of a Gaussian over pre-tanh actions.

    Attributes:
        hidden_dim: width of each hidden ``Dense`` layer.
        output_dim: number of action dimensions (size of both heads).
    """

    hidden_dim: int
    output_dim: int

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """Map ``x`` to ``(mu, log_std)`` of the pre-tanh action distribution.

        NOTE(review): each layer name embeds ``self.__repr__()`` evaluated at
        that point; the names key the layers' parameters, so keep them as-is
        to stay compatible with existing checkpoints.
        """
        x = nn.relu(nn.Dense(self.hidden_dim, name=f"linear1_{self.__repr__()}")(x))
        x = nn.relu(nn.Dense(self.hidden_dim, name=f"linear2_{self.__repr__()}")(x))
        x = nn.relu(nn.Dense(self.hidden_dim, name=f"linear3_{self.__repr__()}")(x))
        mu = nn.Dense(self.output_dim, name=f"linear_mu_{self.__repr__()}")(x)
        log_std = nn.Dense(self.output_dim, name=f"linear_log_std_{self.__repr__()}")(x)
        return mu, log_std

    def get_action(
        self,
        variables: VariableDict,
        state: jnp.ndarray,
        key: jax.random.PRNGKey,
        deterministic: bool = False,
        return_log_prob: bool = True,
    ) -> Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]:
        """Returns a tanh-squashed action for a given state.

        Args:
            variables (VariableDict): weights of the actor.
            state (jnp.ndarray): state of the environment; presumably
                ``(batch, state_dim)`` — the log-probability is reduced over
                the last axis.
            key (jax.random.PRNGKey): JAX random key used for sampling.
            deterministic (bool, optional): when True, return ``tanh(mu)``
                (with gradients stopped through ``mu``) and no log-probability,
                regardless of ``return_log_prob``. Defaults to False.
            return_log_prob (bool, optional): when True (and not
                deterministic), also return the log-probability of the sampled
                action. Defaults to True.

        Returns:
            ``action`` alone, or ``(action, log_prob)`` when sampling with
            ``return_log_prob``. ``action`` lies in (-1, 1); ``log_prob`` is
            summed over the action axis with ``keepdims=True``.
        """
        mu, log_std = self.apply(variables, state)
        if deterministic:
            mu = jax.lax.stop_gradient(mu)
            return jnp.tanh(mu)

        normal_dist = tfd.Normal(mu, jnp.exp(log_std))
        pre_tanh_sample = normal_dist.sample(seed=key)
        action = jnp.tanh(pre_tanh_sample)
        if return_log_prob:
            # Change of variables for a = tanh(u):
            #   log p(a) = sum_i [log N(u_i) - log(1 - tanh(u_i)^2)]
            # Both terms are per-dimension and must be reduced TOGETHER over
            # the action axis. The 1e-6 keeps the log finite when |a| -> 1.
            log_prob = (
                normal_dist.log_prob(pre_tanh_sample) - jnp.log(1 - action**2 + 1e-6)
            ).sum(axis=-1, keepdims=True)
            return action, log_prob
        return action

    def get_config_dict(self, name: str = "") -> OmegaConf:
        """Return ``hidden_dim`` and ``output_dim`` as a config.

        Args:
            name (str, optional): when non-empty, each key is prefixed with
                ``"<name>/"`` and the config is wrapped as
                ``{"<name>/": config}`` in a plain dict. Defaults to "".

        Returns:
            The DictConfig itself when ``name`` is empty, otherwise the plain
            dict wrapping it.
        """
        config = OmegaConf.create({})
        if name != "":
            name += "/"
        config[f"{name}hidden_dim"] = self.hidden_dim
        config[f"{name}output_dim"] = self.output_dim
        if name == "":
            return config
        return {name: config}

    def get_search_space(self, prefix: str = ""):
        """Return a Ray Tune grid search over ``hidden_dim``.

        Args:
            prefix (str, optional): when non-empty, the key becomes
                ``"<prefix>/hidden_dim"``. Defaults to "".
        """
        if prefix != "":
            prefix += "/"
        config = {f"{prefix}hidden_dim": tune.grid_search([32, 64, 128, 256])}
        return config
class ActorDiscrete(Policy):
    """Actor for a discrete action space.

    A two-hidden-layer ReLU MLP whose output layer is passed through a softmax,
    giving one probability per action.

    Attributes:
        hidden_dim: width of each hidden ``Dense`` layer.
        output_dim: number of discrete actions.
    """

    hidden_dim: int
    output_dim: int

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """Map ``x`` to action probabilities (softmax over the last axis).

        NOTE(review): each layer name embeds ``self.__repr__()`` evaluated at
        that point; the names key the layers' parameters, so keep them as-is
        to stay compatible with existing checkpoints.
        """
        x = nn.relu(nn.Dense(self.hidden_dim, name=f"linear1_{self.__repr__()}")(x))
        x = nn.relu(nn.Dense(self.hidden_dim, name=f"linear2_{self.__repr__()}")(x))
        x = nn.Dense(self.output_dim, name=f"linear3_{self.__repr__()}")(x)
        return nn.softmax(x)

    def get_action(
        self,
        variables: VariableDict,
        state: jnp.ndarray,
        key: jax.random.PRNGKey,
        deterministic: bool = False,
        return_log_prob: bool = True,
    ) -> Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]]:
        """Returns an action for a given state.

        Args:
            variables (VariableDict): weights of the actor.
            state (jnp.ndarray): state of the environment.
            key (jax.random.PRNGKey): JAX random key used for sampling.
            deterministic (bool, optional): when True, return the most probable
                action (argmax over the last axis) without sampling.
                Defaults to False.
            return_log_prob (bool, optional): when True (and not
                deterministic), also return the action probabilities and their
                logarithm. Defaults to True.

        Returns:
            - deterministic: the greedy action indices.
            - ``return_log_prob``: ``(action, action_probability,
              log_action_probability)``.
            - otherwise: ``action_probability`` (see NOTE below).
        """
        action_probability = self.apply(variables, state)
        if deterministic:
            # Greedy action: no randomness involved.
            return jnp.argmax(action_probability, axis=-1)

        # ``__call__`` already applies a softmax, so these are probabilities.
        # Categorical's first positional parameter is ``logits``; pass them
        # explicitly as ``probs`` so sampling uses the right distribution.
        categorical_distribution = tfd.Categorical(probs=action_probability)
        action = categorical_distribution.sample(seed=key)
        if return_log_prob:
            # Add 1e-8 only where a probability is exactly zero so the log
            # never evaluates log(0).
            log_action_probability = jnp.log(
                jnp.asarray(action_probability == 0.0).astype("float") * 1e-8
                + action_probability
            )
            return action, action_probability, log_action_probability
        # NOTE(review): this branch returns the probabilities rather than the
        # sampled ``action``; kept for compatibility with existing callers —
        # confirm whether ``action`` was intended.
        return action_probability

    def get_config_dict(self, name: str = "") -> OmegaConf:
        """Return ``hidden_dim`` and ``output_dim`` as a config.

        Args:
            name (str, optional): when non-empty, each key is prefixed with
                ``"<name>/"`` and the config is wrapped as
                ``{"<name>/": config}`` in a plain dict. Defaults to "".

        Returns:
            The DictConfig itself when ``name`` is empty, otherwise the plain
            dict wrapping it.
        """
        config = OmegaConf.create({})
        if name != "":
            name += "/"
        config[f"{name}hidden_dim"] = self.hidden_dim
        config[f"{name}output_dim"] = self.output_dim
        if name == "":
            return config
        return {name: config}

    def get_search_space(self, prefix: str = ""):
        """Return a Ray Tune grid search over ``hidden_dim``.

        Args:
            prefix (str, optional): when non-empty, the key becomes
                ``"<prefix>/hidden_dim"``. Defaults to "".
        """
        if prefix != "":
            prefix += "/"
        config = {f"{prefix}hidden_dim": tune.grid_search([32, 64, 128, 256])}
        return config
class Critic(Policy):
    """Value network used by actor-critic algorithms.

    Two ReLU hidden layers followed by a linear output layer with
    ``output_dim`` units and no activation.

    Attributes:
        hidden_dim: width of each hidden ``Dense`` layer.
        output_dim: number of output values.
    """

    hidden_dim: int
    output_dim: int

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """Evaluate the critic network on ``x``."""
        # Each layer name is formatted immediately before its layer is created,
        # embedding ``self.__repr__()`` at that moment; the names key the
        # parameters, so keep this per-layer evaluation order.
        for layer_label in ("linear1", "linear2"):
            hidden_layer = nn.Dense(
                self.hidden_dim, name=f"{layer_label}_{self.__repr__()}"
            )
            x = nn.relu(hidden_layer(x))
        output_layer = nn.Dense(self.output_dim, name=f"linear3_{self.__repr__()}")
        return output_layer(x)

    def get_action(self, variables: VariableDict, state: jnp.ndarray) -> jnp.ndarray:
        """Run the critic on ``state`` with the weights held in ``variables``."""
        critic_output = self.apply(variables, state)
        return critic_output

    def get_config_dict(self, name: str = "") -> OmegaConf:
        """Describe ``hidden_dim`` and ``output_dim`` as a config.

        With an empty ``name`` the DictConfig is returned directly; otherwise
        keys are prefixed with ``"<name>/"`` and the config is returned inside
        a plain dict under the key ``"<name>/"``.
        """
        prefix = name if name == "" else name + "/"
        config = OmegaConf.create(
            {
                f"{prefix}hidden_dim": self.hidden_dim,
                f"{prefix}output_dim": self.output_dim,
            }
        )
        return config if name == "" else {prefix: config}

    def get_search_space(self, prefix: str = ""):
        """Build a Ray Tune grid search over ``hidden_dim``, optionally namespaced."""
        search_key = "hidden_dim" if prefix == "" else f"{prefix}/hidden_dim"
        return {search_key: tune.grid_search([32, 64, 128, 256])}