offlax.models
Classes
| ActorContinuous | Actor for a continuous action space |
| ActorDiscrete | Actor for a discrete action space |
| | Critic for an Actor-Critic based algorithm |
| | Generic implementation of a Policy class |
- class offlax.models.ActorContinuous(*args, **kwargs)
Bases: flax.linen.
Actor for a continuous action space
- Parameters
args (Any) –
kwargs (Any) –
- Return type
Any
- get_action(variables, state, key, deterministic=False, return_log_prob=True)
Returns action for a given state
- Parameters
variables (VariableDict) – weights of the actor
state (jnp.ndarray) – state of the environment
key (jax.random.PRNGKey) – JAX random key
deterministic (bool, optional) – when True, returns a deterministic action instead of sampling one. Defaults to False.
return_log_prob (bool, optional) – when True, also returns the probability and the log-probability of the selected action. Defaults to True.
- Returns
Action for the given state.
- Return type
jnp.ndarray
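The parameters above can be illustrated with a minimal sketch in plain JAX. Everything in it is an assumption made for illustration: `sketch_get_action` is a hypothetical stand-in, not offlax's implementation, and the linear mean, the state-independent `log_std`, and the `(action, log_prob)` return layout are choices made here rather than details taken from the library. It only shows how `deterministic`, `key`, and `return_log_prob` typically interact for a Gaussian actor.

```python
import jax
import jax.numpy as jnp


def sketch_get_action(variables, state, key, deterministic=False, return_log_prob=True):
    """Hypothetical stand-in mirroring the documented signature; not offlax's code."""
    params = variables["params"]
    # Assumption: a single linear layer gives the mean, log_std is state-independent.
    mean = state @ params["w"] + params["b"]
    log_std = params["log_std"]
    if deterministic:
        action = mean  # mode of the Gaussian, the key is not used
    else:
        noise = jax.random.normal(key, mean.shape)
        action = mean + jnp.exp(log_std) * noise
    if not return_log_prob:
        return action
    # Log-density of a diagonal Gaussian, summed over the action dimensions.
    z = (action - mean) / jnp.exp(log_std)
    log_prob = jnp.sum(-0.5 * z**2 - log_std - 0.5 * jnp.log(2.0 * jnp.pi), axis=-1)
    return action, log_prob


# Toy weights: 3 state features, 2 action dimensions, batch of 4 states.
variables = {"params": {"w": jnp.zeros((3, 2)), "b": jnp.ones(2), "log_std": jnp.zeros(2)}}
state = jnp.ones((4, 3))
key = jax.random.PRNGKey(0)

greedy = sketch_get_action(variables, state, key, deterministic=True, return_log_prob=False)
sampled, log_prob = sketch_get_action(variables, state, key)
```

Passing the same `key` twice reproduces the same sampled action, which is why callers normally split the key between calls.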
- class offlax.models.ActorDiscrete(*args, **kwargs)
Bases: flax.linen.
Actor for a discrete action space
- Parameters
args (Any) –
kwargs (Any) –
- Return type
Any
- get_action(variables, state, key, deterministic=False, return_log_prob=True)
Returns action for a given state
- Parameters
variables (VariableDict) – weights of the actor
state (jnp.ndarray) – state of the environment
key (jax.random.PRNGKey) – JAX random key
deterministic (bool, optional) – when True, returns a deterministic action instead of sampling one. Defaults to False.
return_log_prob (bool, optional) – when True, also returns the probability and the log-probability of the selected action. Defaults to True.
- Returns
_description_ #TODO: Update this
- Return type
Tuple[jnp.ndarray, jnp.ndarray]
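For the discrete case, a comparable sketch in plain JAX is given here. `sketch_get_discrete_action` is a hypothetical stand-in, not offlax's implementation. The linear logits layer and the assumption that the returned tuple holds `(action, log_prob)` are illustrative guesses, since the documentation above does not describe the tuple's contents.

```python
import jax
import jax.numpy as jnp


def sketch_get_discrete_action(variables, state, key, deterministic=False):
    """Hypothetical stand-in for a discrete actor; not offlax's code."""
    params = variables["params"]
    logits = state @ params["w"] + params["b"]  # one logit per discrete action
    if deterministic:
        action = jnp.argmax(logits, axis=-1)  # most likely action
    else:
        action = jax.random.categorical(key, logits, axis=-1)  # sample from softmax(logits)
    # Log-probability of the chosen action under the categorical distribution.
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    log_prob = jnp.take_along_axis(log_probs, action[..., None], axis=-1).squeeze(-1)
    return action, log_prob


# Toy weights: 3 state features, 5 actions, with the bias favouring action 2.
variables = {"params": {"w": jnp.zeros((3, 5)), "b": jnp.array([0.0, 0.0, 4.0, 0.0, 0.0])}}
state = jnp.ones((2, 3))

key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)  # use a fresh subkey for each sampling call

action, log_prob = sketch_get_discrete_action(variables, state, subkey)
greedy_action, _ = sketch_get_discrete_action(variables, state, subkey, deterministic=True)
```

Splitting the key before each call keeps successive samples independent while the run as a whole stays reproducible from the initial seed.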