offlax.models

Classes

ActorContinuous(*args, **kwargs)

Actor for a continuous action space

ActorDiscrete(*args, **kwargs)

Actor for a discrete action space

Critic(*args, **kwargs)

Critic for an Actor-Critic based algorithm

Policy(*args, **kwargs)

Generic implementation of a Policy class

class offlax.models.ActorContinuous(*args, **kwargs)

Bases: flax.linen.

Actor for a continuous action space

Parameters
  • args (Any) –

  • kwargs (Any) –

Return type

Any

get_action(variables, state, key, deterministic=False, return_log_prob=True)

Returns an action for the given state.

Parameters
  • variables (VariableDict) – weights of the actor

  • state (jnp.ndarray) – state of the environment

  • key (jax.random.PRNGKey) – JAX random key

  • deterministic (bool, optional) – If True, returns a deterministic action. Defaults to False.

  • return_log_prob (bool, optional) – If True, also returns the probability and the log-probability of the actor's actions. Defaults to True.

Returns

The action for the given state.

Return type

jnp.ndarray

class offlax.models.ActorDiscrete(*args, **kwargs)

Bases: flax.linen.

Actor for a discrete action space

Parameters
  • args (Any) –

  • kwargs (Any) –

Return type

Any

get_action(variables, state, key, deterministic=False, return_log_prob=True)

Returns an action for the given state.

Parameters
  • variables (VariableDict) – weights of the actor

  • state (jnp.ndarray) – state of the environment

  • key (jax.random.PRNGKey) – JAX random key

  • deterministic (bool, optional) – If True, returns a deterministic action. Defaults to False.

  • return_log_prob (bool, optional) – If True, also returns the probability and the log-probability of the actor's actions. Defaults to True.

Returns

The action for the given state.

Return type

Tuple[jnp.ndarray, jnp.ndarray]

class offlax.models.Critic(*args, **kwargs)

Bases: flax.linen.

Critic for an Actor-Critic based algorithm

Parameters
  • args (Any) –

  • kwargs (Any) –

Return type

Any

class offlax.models.Policy(*args, **kwargs)

Bases: flax.linen.

Generic implementation of a Policy class

Parameters
  • args (Any) –

  • kwargs (Any) –

Return type

Any