from __future__ import annotations
from copy import deepcopy
from functools import partial
from typing import List, Tuple, Dict
import d4rl
import gym
import jax
import numpy as np
import optax
import wandb
from jax import numpy as jnp
from flax import linen as nn
from omegaconf import OmegaConf
from ray import tune
from ray.air import session
try:
from tqdm import TqdmExperimentalWarning
from tqdm.rich import tqdm
except ImportError:
# Rich not installed, we only throw an error
# if the progress bar is used
tqdm = None
from offlax.models import ActorDiscrete, ActorContinuous, Critic
from offlax.runner import OfflaxRunner
def sample(trajectories: Dict, rng: jax.random.PRNGKey, batch_size: int):
indices = np.random.randint(trajectories["observations"].shape[0], size=batch_size)
sample_trajectory = {}
for key in trajectories.keys():
sample_trajectory[key] = jax.tree_util.tree_map(
jax.device_put, trajectories[key][indices, ...]
)
return sample_trajectory
[docs]class CQLDiscrete:
"""Implementation of Conservative Q Learning (CQL) algorithm.
Paper: https://arxiv.org/abs/2006.04779
"""
def __init__(
self,
rng: jax.random.PRNGKey,
actor: ActorDiscrete,
critic: Critic,
state_dims: List[int],
action_dims: int,
gamma: float,
tau: float,
):
self.rng = rng
self.actor = actor
self.actor_variables = self.actor.init(rng, jnp.ones((1, state_dims)))
self.actor_optimizer = optax.adam(1e-3)
self.actor_optimizer_variables = self.actor_optimizer.init(self.actor_variables)
self.critic1 = critic
self.rng, _ = jax.random.split(self.rng)
self.critic1_variables = self.critic1.init(rng, jnp.ones((1, state_dims)))
self.critic1_optimizer = optax.adam(1e-3)
self.critic1_optimizer_variables = self.critic1_optimizer.init(
self.critic1_variables
)
self.critic_target1 = deepcopy(critic)
self.critic_target1_variables = deepcopy(self.critic1_variables)
self.critic2 = critic
self.rng, _ = jax.random.split(self.rng)
self.critic2_variables = self.critic2.init(rng, jnp.ones((1, state_dims)))
self.critic2_optimizer = optax.adam(1e-3)
self.critic2_optimizer_variables = self.critic2_optimizer.init(
self.critic2_variables
)
self.critic_target2 = deepcopy(critic)
self.critic_target2_variables = deepcopy(self.critic2_variables)
self.alpha = jnp.zeros((1))
self.alpha_optimizer = optax.adam(1e-3)
self.alpha_optimizer_variables = self.alpha_optimizer.init(self.alpha)
self.target_entropy = -float(action_dims)
self.gamma = gamma
self.tau = tau
self.state_dims = state_dims
self.action_dims = action_dims
@jax.jit
def get_action(
self, state: jnp.ndarray, train: bool = False, rng: jax.random.PRNGKey = None
) -> jnp.ndarray:
state = jax.lax.stop_gradient(state)
assert rng is not None
self.rng, key = jax.random.split(rng)
action = self.actor.get_action(
self.actor_variables, state, deterministic=not train, key=key
)
action = jax.lax.stop_gradient(action)
return action
@partial(jax.jit, static_argnums=(0,))
def get_actor_loss(
self,
states: jnp.ndarray,
actor_variables,
critic1_variables,
critic2_variables,
alpha: float,
loss_key,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
actions, action_probabilities, log_preds_actions = self.actor.get_action(
actor_variables, states, loss_key, return_log_prob=True
)
q1 = self.critic1.apply(critic1_variables, states)
q2 = self.critic2.apply(critic2_variables, states)
min_q = jnp.minimum(q1, q2)
actor_loss = jnp.mean(
jnp.sum(action_probabilities * (alpha * log_preds_actions - min_q), axis=1)
)
log_action_sum = jnp.sum(log_preds_actions * action_probabilities)
return actor_loss, log_action_sum
@partial(jax.jit, static_argnums=(0,))
def get_alpha_loss(self, alpha: jnp.ndarray, log_preds: jnp.ndarray) -> jnp.ndarray:
return -(alpha * jax.lax.stop_gradient(log_preds + self.target_entropy)).mean()
@partial(jax.jit, static_argnums=(0,))
def get_critic_loss(
self,
states: jnp.ndarray,
next_states: jnp.ndarray,
rewards: jnp.ndarray,
dones: jnp.ndarray,
critic1_variables,
critic2_variables,
actor_key,
):
action, action_probs, log_prob_sum = self.actor.get_action(
self.actor_variables,
next_states,
actor_key,
deterministic=False,
return_log_prob=True,
)
q_target_1_next = jax.lax.stop_gradient(
self.critic_target1.apply(self.critic_target1_variables, next_states)
)
q_target_2_next = jax.lax.stop_gradient(
self.critic_target2.apply(self.critic_target2_variables, next_states)
)
q_target_next = jax.lax.stop_gradient(
action_probs
* (
jnp.minimum(q_target_1_next, q_target_2_next)
- self.alpha * log_prob_sum
)
)
q_targets = jax.lax.stop_gradient(
rewards
+ (self.gamma * (1 - dones) * jnp.expand_dims(q_target_next.sum(1), 1))
)
q1 = self.critic1.apply(critic1_variables, states)
q2 = self.critic2.apply(critic2_variables, states)
q1_ = jnp.take(q1, action.astype("long"), 1)
q2_ = jnp.take(q2, action.astype("long"), 1)
critic1_loss = 0.5 * jnp.square(q1_ - q_targets)
critic2_loss = 0.5 * jnp.square(q2_ - q_targets)
cql1_scaled_loss = jnp.log(jnp.sum(jnp.exp(q1), 1))
cql2_scaled_loss = jnp.log(jnp.sum(jnp.exp(q2), 1))
total_c1_loss = (critic1_loss).sum() + (cql1_scaled_loss).sum()
total_c2_loss = (critic2_loss).sum() + (cql2_scaled_loss).sum()
return total_c1_loss, total_c2_loss
def step(self, experience_batch):
states, _, rewards, next_states, dones = experience_batch
# Calculate Actor loss and update actor variables
alpha = deepcopy(self.alpha)
self.rng, key = jax.random.split(self.rng)
(actor_loss, log_preds_actor), actor_gradients = jax.value_and_grad(
self.get_actor_loss, has_aux=True, argnums=1
)(
states,
self.actor_variables,
self.critic1_variables,
self.critic2_variables,
alpha,
key,
)
actor_updates, self.actor_optimizer_variables = self.actor_optimizer.update(
actor_gradients, self.actor_optimizer_variables
)
self.actor_variables = optax.apply_updates(self.actor_variables, actor_updates)
# Update alpha
alpha_loss, alpha_gradients = jax.value_and_grad(self.get_alpha_loss)(
self.alpha, log_preds_actor
)
alpha_updates, self.alpha_optimizer_variables = self.alpha_optimizer.update(
alpha_gradients, self.alpha_optimizer_variables, self.alpha
)
self.alpha = optax.apply_updates(self.alpha, alpha_updates)
# Update critics
self.rng, key = jax.random.split(self.rng)
(total_c1_loss, total_c2_loss), (
critic1_gradients,
critic2_gradients,
) = jax.value_and_grad(self.get_critic_loss, has_aux=True, argnums=[4, 5])(
states,
next_states,
rewards,
dones,
self.critic1_variables,
self.critic2_variables,
key,
)
(
critic1_updates,
self.critic1_optimizer_variables,
) = self.critic1_optimizer.update(
critic1_gradients, self.critic1_optimizer_variables, self.critic1_variables
)
self.critic1_variables = optax.apply_updates(
self.critic1_variables, critic1_updates
)
(
critic2_updates,
self.critic2_optimizer_variables,
) = self.critic2_optimizer.update(
critic2_gradients, self.critic2_optimizer_variables, self.critic2_variables
)
self.critic2_variables = optax.apply_updates(
self.critic2_variables, critic2_updates
)
# Update target critics
self.critic_target1_variables = jax.tree_map(
lambda p, target_p: p * self.tau + target_p * (1 - self.tau),
self.critic1_variables,
self.critic_target1_variables,
)
self.critic_target2_variables = jax.tree_map(
lambda p, target_p: p * self.tau + target_p * (1 - self.tau),
self.critic2_variables,
self.critic_target2_variables,
)
return total_c1_loss, total_c2_loss, alpha_loss, actor_loss
def get_config_dict(self):
config = self.actor.get_config_dict("actor")
config.update(self.critic1.get_config_dict("critic"))
config["gamma"] = self.gamma
config["tau"] = self.tau
config["seed"] = int(self.rng[0])
config["state_dims"] = self.state_dims
config["action_dims"] = self.action_dims
config["iterations"] = 1e5
config["batch_size"] = 128
return config
def get_search_space(self):
config = self.actor.get_search_space("actor")
config.update(self.critic1.get_search_space("critic"))
config["gamma"] = tune.uniform(0.95, 0.99)
config["tau"] = tune.uniform(0.95, 0.99)
config["seed"] = tune.grid_search([40, 41, 42, 43, 44, 45])
config["state_dims"] = self.state_dims
config["action_dims"] = self.action_dims
config["iterations"] = tune.grid_search([1e5])
config["batch_size"] = tune.grid_search([128, 256])
return config
[docs] def get_search_metric(self) -> Tuple[str, str]:
"""Returns the search metric for hyperparameter tuning
Returns:
Tuple[str, str]: (objective=['min', 'max'], objective metric)
"""
return "min", "total_c1_loss"
@classmethod
def parse_config(cls, config: Dict) -> CQLDiscrete:
return CQLDiscrete(
jax.random.PRNGKey(config["seed"]),
ActorDiscrete(config["actor/hidden_dim"], config["action_dims"]),
Critic(config["critic/hidden_dim"], 1),
config["state_dims"],
config["action_dims"],
config["gamma"],
config["tau"],
)
def train(
self,
config: Dict = None,
environment: str = "maze2d-open-v0",
rtune: tune = None,
enable_wandb: bool = True,
):
if config is None:
config = self.get_config_dict()
env = gym.make(environment)
env.reset()
dataset = d4rl.qlearning_dataset(env)
for iteration in tqdm(range(int(config["iterations"]))):
self.rng, _ = jax.random.split(self.rng)
batch_sample = sample(dataset, self.rng, config["batch_size"])
total_c1_loss, total_c2_loss, alpha_loss, actor_loss = self.step(
[
batch_sample["observations"],
batch_sample["actions"],
batch_sample["rewards"],
batch_sample["next_observations"],
batch_sample["terminals"],
]
)
iteration_metric = {
"total_c1_loss": total_c1_loss,
"total_c2_loss": total_c2_loss,
"alpha_loss": alpha_loss,
"actor_loss": actor_loss,
}
if enable_wandb:
wandb.log(iteration_metric)
if rtune:
session.report(iteration_metric)