tree
TreeParameters
Dataclass for holding the tree parameters.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `bandwidth` | | the bucket half-width covered by each action centroid. | required |
| `discretization_parameter` | | the number of action centroids. | required |
| `action_space` | | the range from which actions can be generated. | required |
| `depth` | | the number of layers in the tree. | required |
| `spaces` | | an array indicating the start and end of the range covered by each action centroid. | required |
| `volumes` | | an array indicating the bandwidth around each action centroid. | required |
| `probabilities` | | h-smoothing of policy π_t (one over volumes). | required |
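
The `probabilities` entry can be read as the usual bandwidth-h smoothing of a deterministic policy: the chosen centroid is replaced by a uniform density over the interval around it, clipped to the action space. The formula below is a sketch under that reading, not a statement of the library's implementation; the context symbol $x$, the action space $\mathcal{A}$, and the operator $\operatorname{vol}$ are notation assumed here for illustration.

```latex
\operatorname{Smooth}_h(\pi_t)(a \mid x)
  = \frac{\mathbb{1}\!\left\{ a \in [\pi_t(x) - h,\ \pi_t(x) + h] \cap \mathcal{A} \right\}}
         {\operatorname{vol}\!\left( [\pi_t(x) - h,\ \pi_t(x) + h] \cap \mathcal{A} \right)}
```

Under this reading, the denominator corresponds to an entry of `volumes`, and its reciprocal to the matching entry of `probabilities`.
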
construct(discretization_parameter, bandwidth)
classmethod
A constructor for calculating and initializing the tree parameters.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `discretization_parameter` | `int` | the number of action centroids. | required |
| `bandwidth` | `float` | the bucket half-width covered by each action centroid. | required |
Returns:
| Type | Description |
|---|---|
| `TreeParameters` | An initialized TreeParameters dataclass. |
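
To make the relationship between the two constructor arguments and the derived fields concrete, here is a minimal, dependency-free sketch. It is not the library's implementation. The class name `SketchTreeParameters`, the unit-interval action space, the midpoint placement of centroids, and the power-of-two requirement (one leaf per centroid in a binary tree) are all assumptions made for illustration. The documented fields are arrays, while plain lists are used here.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class SketchTreeParameters:
    """Hypothetical stand-in with the same fields as the documented dataclass."""

    bandwidth: float
    discretization_parameter: int
    action_space: Tuple[float, float]
    depth: int
    spaces: List[Tuple[float, float]]
    volumes: List[float]
    probabilities: List[float]

    @classmethod
    def construct(
        cls, discretization_parameter: int, bandwidth: float
    ) -> "SketchTreeParameters":
        if discretization_parameter < 2 or bandwidth <= 0:
            raise ValueError("need at least 2 centroids and a positive bandwidth")

        # Assumption: a binary tree with one leaf per centroid, so the number
        # of centroids must be a power of two and depth = log2(centroids).
        depth = discretization_parameter.bit_length() - 1
        if 1 << depth != discretization_parameter:
            raise ValueError("discretization_parameter must be a power of two")

        # Assumption: actions are normalised to the unit interval.
        low, high = 0.0, 1.0

        # Assumption: centroids sit at the midpoints of equal-width bins.
        bin_width = (high - low) / discretization_parameter
        centroids = [
            low + (i + 0.5) * bin_width for i in range(discretization_parameter)
        ]

        # Each centroid covers [c - h, c + h], clipped to the action space.
        spaces = [
            (max(low, c - bandwidth), min(high, c + bandwidth)) for c in centroids
        ]
        # Volume is the length of each clipped interval.
        volumes = [end - start for start, end in spaces]
        # Smoothed density is one over the volume.
        probabilities = [1.0 / v for v in volumes]

        return cls(
            bandwidth=bandwidth,
            discretization_parameter=discretization_parameter,
            action_space=(low, high),
            depth=depth,
            spaces=spaces,
            volumes=volumes,
            probabilities=probabilities,
        )


params = SketchTreeParameters.construct(discretization_parameter=4, bandwidth=0.25)
print(params.depth)    # 2
print(params.volumes)  # [0.375, 0.5, 0.5, 0.375]
```

In this sketch the two outer centroids have smaller volumes because their intervals are clipped at the edges of the action space, so their smoothed densities are correspondingly larger.
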
Tree
Bases: hk.Module
__init__(catx_network, tree_params, name=None)
The tree as a JAX Haiku module.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `catx_network` | `Type[CATXHaikuNetwork]` | class specifying the neural network architecture. | required |
| `tree_params` | `TreeParameters` | object holding the tree parameters. | required |
| `name` | `Optional[str]` | name of the tree. | `None` |