catx

CATXState

Holds the CATX's training state and extra parameterization for the networks.

CATX

This class runs CATX action sampling and training of the tree.

__init__(catx_network, optimizer, discretization_parameter, bandwidth, action_min=0.0, action_max=1.0)

Instantiate a CATX instance with its corresponding tree.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| catx_network | Type[CATXHaikuNetwork] | class specifying the neural network architecture. | required |
| optimizer | optax.GradientTransformation | optax optimizer object. | required |
| discretization_parameter | int | the number of action centroids. | required |
| bandwidth | float | the half-width of the bucket covered by each action centroid. | required |
| action_min | float | the lowest value of the action space. | 0.0 |
| action_max | float | the highest value of the action space. | 1.0 |
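
To make the interplay of `discretization_parameter`, `bandwidth`, `action_min` and `action_max` concrete, here is a minimal standalone sketch. It assumes the centroids sit at the midpoints of equal-width buckets and that the bandwidth interval is clipped to the action range; this page does not state how CATX places its centroids, so `action_centroids` and `smoothing_interval` are hypothetical helpers, not part of the library.

```python
def action_centroids(discretization_parameter, action_min=0.0, action_max=1.0):
    """Midpoints of equal-width buckets (assumed layout, not taken from CATX)."""
    width = (action_max - action_min) / discretization_parameter
    return [action_min + (i + 0.5) * width for i in range(discretization_parameter)]


def smoothing_interval(centroid, bandwidth, action_min=0.0, action_max=1.0):
    """Interval of half-width `bandwidth` around a centroid, clipped to the action range."""
    low = max(action_min, centroid - bandwidth)
    high = min(action_max, centroid + bandwidth)
    return low, high


centroids = action_centroids(4)                    # [0.125, 0.375, 0.625, 0.875]
interval = smoothing_interval(centroids[0], 0.25)  # (0.0, 0.375)
```

Under these assumptions, a larger `discretization_parameter` gives a finer grid of candidate actions, while a larger `bandwidth` spreads each candidate over a wider range of continuous actions.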

sample(obs, epsilon, state)

Samples an action from the tree.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| obs | Observations | the observations, i.e., batched contexts. | required |
| epsilon | float | probability of selecting a random action. | required |
| state | CATXState | holds the CATX's training state. | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| actions | Actions | sampled actions from the tree using epsilon-greedy. |
| probabilities | Probabilities | the probability density value of the sampled actions. |
| state | CATXState | holds the CATX's training state. |
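
The returned `probabilities` are density values rather than discrete probabilities, so they can exceed 1. The following standalone sketch illustrates one way an epsilon-greedy sampler over a continuous action range can report such a density. It assumes the greedy action is drawn uniformly within `bandwidth` of a chosen centroid; `sample_epsilon_greedy` and its arguments are hypothetical and do not reproduce the library's internals.

```python
import random


def sample_epsilon_greedy(greedy_centroid, epsilon, bandwidth, rng,
                          action_min=0.0, action_max=1.0):
    """Return (action, density) for a single context (illustrative only)."""
    low = max(action_min, greedy_centroid - bandwidth)
    high = min(action_max, greedy_centroid + bandwidth)
    if rng.random() < epsilon:
        # Explore: uniform over the whole action range.
        action = rng.uniform(action_min, action_max)
    else:
        # Exploit: uniform inside the interval around the greedy centroid.
        action = rng.uniform(low, high)
    # Density of the mixture evaluated at the sampled action.
    density = epsilon / (action_max - action_min)
    if low <= action <= high:
        density += (1.0 - epsilon) / (high - low)
    return action, density


rng = random.Random(0)
action, density = sample_epsilon_greedy(0.5, epsilon=0.1, bandwidth=0.1, rng=rng)
```

With `epsilon=0.0` and `bandwidth=0.1` on the unit interval, every sampled action in this sketch has density of about 5, which is why the value is described as a density.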

learn(obs, actions, probabilities, costs, state)

Updates the tree:
  • updates the parameters of the depth-specific neural networks.
  • copies the parameters of the depth-specific neural networks to the tree neural networks.
  • updates the state of the optimizers.
  • updates the pseudo-random number generator key.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| obs | Observations | the observations, i.e., batched contexts. | required |
| actions | Actions | The executed actions associated with the given observations. | required |
| probabilities | Probabilities | the probability density value of the actions. | required |
| costs | Costs | The costs incurred from the executed actions. | required |
| state | CATXState | holds the CATX's training state. | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| state | CATXState | holds the CATX's training state. |
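
One reason a learner may need the logged `probabilities` alongside `actions` and `costs` is inverse-propensity weighting, which corrects for the fact that only the executed action's cost is observed. The sketch below shows the generic estimator for a centroid whose smoothed policy is uniform within `bandwidth` of the centroid. Whether CATX uses exactly this estimator is not stated on this page, and `ips_cost_estimate` is a hypothetical helper.

```python
def ips_cost_estimate(cost, probability, action, centroid, bandwidth):
    """Inverse-propensity estimate of the cost of playing `centroid` (illustrative only).

    The target policy is assumed uniform on [centroid - bandwidth, centroid + bandwidth],
    i.e. it has density 1 / (2 * bandwidth) there and 0 elsewhere.
    """
    if abs(action - centroid) > bandwidth:
        # The logged action carries no information about this centroid.
        return 0.0
    target_density = 1.0 / (2.0 * bandwidth)
    # Reweight the observed cost by target density over logging density.
    return cost * target_density / probability


estimate = ips_cost_estimate(cost=1.0, probability=5.0, action=0.5,
                             centroid=0.5, bandwidth=0.1)
```

Actions that were unlikely under the logging policy (small `probability`) receive a larger weight, which is why the density returned at sampling time has to be stored and passed back unchanged.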

init(obs, key, epsilon, network_extras=None)

Initializes the parameters of the tree's neural networks, the forward functions, and the optimizer states.

This function can only be called once. It is called the first time a CATX instance is used.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| obs | JaxObservations | the observations, i.e., batched contexts. | required |
| key | chex.PRNGKey | pseudo-random number generator key. | required |
| epsilon | float | probability of selecting a random action. | required |
| network_extras | Optional[NetworkExtras] | additional information for querying the neural networks. | None |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| state | CATXState | holds the CATX's training state. |
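
All three methods take or return a state object, which suggests a functional pattern in which the caller threads the state through every call. The sketch below illustrates that call order with a toy stand-in that mirrors the signatures documented on this page. `ToyState`, `ToyBandit`, `run` and the cost function are all hypothetical; nothing here imports or reproduces the real `CATX` class.

```python
import random
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyState:
    """Hypothetical stand-in for CATXState: a seed and an update counter."""
    seed: int
    steps: int = 0


class ToyBandit:
    """Hypothetical stand-in mirroring the documented init/sample/learn signatures."""

    def init(self, obs, key, epsilon, network_extras=None):
        return ToyState(seed=key)

    def sample(self, obs, epsilon, state):
        rng = random.Random(state.seed)
        actions = [rng.random() for _ in obs]
        probabilities = [1.0 for _ in obs]  # uniform density on [0, 1]
        # Return a new state so the next call draws different actions.
        return actions, probabilities, ToyState(state.seed + 1, state.steps)

    def learn(self, obs, actions, probabilities, costs, state):
        return ToyState(state.seed, state.steps + 1)


def run(bandit, batches, epsilon=0.1, key=0):
    """Call init once, then alternate sample and learn, always passing the latest state."""
    state = bandit.init(batches[0], key, epsilon)
    for obs in batches:
        actions, probabilities, state = bandit.sample(obs, epsilon, state)
        costs = [abs(a - 0.5) for a in actions]  # made-up environment cost
        state = bandit.learn(obs, actions, probabilities, costs, state)
    return state


final_state = run(ToyBandit(), [[0.0, 1.0], [2.0, 3.0]])
```

The point of the pattern is that no call mutates hidden state: reusing an old state object would replay the same random draws and discard later updates.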