catx
CATXState
Holds the CATX's training state and extra parameterization for the networks.
CATX
This class runs CATX action sampling and training of the tree.
__init__(catx_network, optimizer, discretization_parameter, bandwidth, action_min=0.0, action_max=1.0)
Instantiates a CATX object with its corresponding tree.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| catx_network | Type[CATXHaikuNetwork] | class specifying the neural network architecture. | required |
| optimizer | optax.GradientTransformation | optax optimizer object. | required |
| discretization_parameter | int | the number of action centroids. | required |
| bandwidth | float | the half-width of the bucket covered by each action centroid. | required |
| action_min | float | the lowest value of the action space. | 0.0 |
| action_max | float | the highest value of the action space. | 1.0 |
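
As a rough illustration of how discretization_parameter, bandwidth, action_min and action_max relate, the sketch below places evenly spaced centroids at the centres of equal-width bins and builds the bucket around one of them. Both helpers (action_centroids, bucket) are hypothetical and not part of catx. The bin-centre placement, the clipping to the action range, and the assumption that bandwidth is expressed in action units are choices made for this example only.

```python
def action_centroids(discretization_parameter, action_min=0.0, action_max=1.0):
    """Evenly spaced bin centres (hypothetical helper, not part of catx)."""
    width = (action_max - action_min) / discretization_parameter
    return [action_min + (i + 0.5) * width for i in range(discretization_parameter)]


def bucket(centroid, bandwidth, action_min=0.0, action_max=1.0):
    """Interval of half-width `bandwidth` around a centroid, clipped to the action range."""
    return max(action_min, centroid - bandwidth), min(action_max, centroid + bandwidth)


centroids = action_centroids(discretization_parameter=4)
print(centroids)                   # [0.125, 0.375, 0.625, 0.875]
print(bucket(centroids[0], 0.25))  # (0.0, 0.375)
```

With a bandwidth larger than half the bin width, as here, neighbouring buckets overlap.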
sample(obs, epsilon, state)
Samples an action from the tree.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| obs | Observations | the observations, i.e., batched contexts. | required |
| epsilon | float | probability of selecting a random action. | required |
| state | CATXState | holds the CATX's training state. | required |
Returns:
| Name | Type | Description |
|---|---|---|
| actions | Actions | sampled actions from the tree using epsilon-greedy. |
| probabilities | Probabilities | the probability density value of the sampled actions. |
| state | CATXState | holds the CATX's training state. |
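
To make "probability density value of the sampled actions" concrete, here is a minimal sketch of one common smoothed epsilon-greedy scheme for continuous actions. With probability epsilon the action is drawn uniformly from the whole range; otherwise it is drawn uniformly from the bucket around the greedy centroid. The function sample_smoothed is hypothetical and is not the catx implementation. It assumes the bucket lies entirely inside the action range.

```python
import random


def sample_smoothed(centroid, epsilon, bandwidth,
                    action_min=0.0, action_max=1.0, rng=random):
    """Draw one action and return (action, density). Hypothetical, not catx code.

    Assumes [centroid - bandwidth, centroid + bandwidth] lies inside the action range.
    """
    low, high = centroid - bandwidth, centroid + bandwidth
    if rng.random() < epsilon:
        action = rng.uniform(action_min, action_max)  # explore: whole range
    else:
        action = rng.uniform(low, high)  # exploit: bucket around the centroid
    # Density of the mixture: uniform exploration term everywhere,
    # plus the bucket term when the action falls inside the bucket.
    density = epsilon / (action_max - action_min)
    if low <= action <= high:
        density += (1.0 - epsilon) / (2.0 * bandwidth)
    return action, density


rng = random.Random(0)
action, density = sample_smoothed(centroid=0.5, epsilon=0.1, bandwidth=0.125, rng=rng)
print(action, density)
```

Note that the returned value is a density, not a probability, so it can exceed 1 when the bucket is narrow.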
learn(obs, actions, probabilities, costs, state)
Updates the tree:
- updates the parameters of the depth-specific neural networks.
- copies the parameters of the depth-specific neural networks to the tree neural networks.
- updates the state of the optimizers.
- updates the pseudo-random key generator.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| obs | Observations | the observations, i.e., batched contexts. | required |
| actions | Actions | the executed actions associated with the given observations. | required |
| probabilities | Probabilities | the probability density value of the actions. | required |
| costs | Costs | the costs incurred from the executed actions. | required |
| state | CATXState | holds the CATX's training state. | required |
Returns:
| Name | Type | Description |
|---|---|---|
| state | CATXState | holds the CATX's training state. |
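
This page does not say how learn uses probabilities. In contextual bandits, a common reason to pass the logged density alongside the cost is inverse propensity weighting, where each observed cost is divided by the density of the action that produced it. The helper below is a hypothetical sketch of that general idea, not a description of the catx loss.

```python
def inverse_propensity_costs(costs, probabilities):
    """Divide each observed cost by the density of its action (hypothetical helper)."""
    return [cost / density for cost, density in zip(costs, probabilities)]


# Two logged interactions: one greedy draw (high density), one exploratory draw (low density).
weighted = inverse_propensity_costs(costs=[0.2, 0.8], probabilities=[4.0, 0.5])
print(weighted)
```

Costs from rarely chosen actions are scaled up, which compensates for how seldom they appear in the logged data.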
init(obs, key, epsilon, network_extras=None)
Initializes the parameters of the tree's neural networks, the forward functions, and the optimizer states.
This function can only be called once. It is called the first time a CATX instance is used.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| obs | JaxObservations | the observations, i.e., batched contexts. | required |
| key | chex.PRNGKey | pseudo-random number generator key. | required |
| epsilon | float | probability of selecting a random action. | required |
| network_extras | Optional[NetworkExtras] | additional information for querying the neural networks. | None |
Returns:
| Name | Type | Description |
|---|---|---|
| state | CATXState | holds the CATX's training state. |
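
The signatures above suggest a functional pattern: init creates the state once, and every later call to sample or learn takes a state and returns a new one. The sketch below shows only that call pattern. ToyBandit and ToyState are hypothetical stand-ins that mirror the documented argument names. They do not use catx, JAX, or neural networks, and the integer key replaces a real chex.PRNGKey purely for illustration.

```python
import random
from typing import NamedTuple


class ToyState(NamedTuple):
    """Stand-in for CATXState: an immutable bundle replaced on every call."""
    seed: int
    steps: int


class ToyBandit:
    """Hypothetical stub that only mirrors the documented call signatures."""

    def init(self, obs, key, epsilon, network_extras=None):
        return ToyState(seed=key, steps=0)

    def sample(self, obs, epsilon, state):
        rng = random.Random(state.seed)
        actions = [rng.random() for _ in obs]  # uniform actions on [0, 1]
        probabilities = [1.0 for _ in obs]     # uniform density on [0, 1]
        return actions, probabilities, state._replace(seed=state.seed + 1)

    def learn(self, obs, actions, probabilities, costs, state):
        return state._replace(steps=state.steps + 1)  # a real learner updates parameters


bandit = ToyBandit()
obs = [[0.1, 0.2], [0.3, 0.4]]  # a batch of two contexts
state = bandit.init(obs=obs, key=0, epsilon=0.1)  # called once
for _ in range(3):
    actions, probabilities, state = bandit.sample(obs=obs, epsilon=0.1, state=state)
    costs = [abs(a - 0.5) for a in actions]  # toy environment feedback
    state = bandit.learn(obs, actions, probabilities, costs, state)
print(state.steps)  # 3
```

Reassigning state after each call keeps the latest state in use for the next step.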