tree

TreeParameters

Dataclass holding the tree parameters.

Parameters:

Name | Type | Description | Default
--- | --- | --- | ---
bandwidth | | the bucket half-width covered by each action centroid. | required
discretization_parameter | | the number of action centroids. | required
action_space | | the range from which actions can be generated. | required
depth | | the number of layers in the tree. | required
spaces | | an array indicating the start and end of the range covered by each action centroid. | required
volumes | | an array indicating the bandwidth around each action centroid. | required
probabilities | | h-smoothing of the policy π_t, equal to one over volumes (1 / volumes). | required
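The relationship between `spaces`, `volumes`, and `probabilities` can be illustrated with a minimal NumPy sketch. It assumes CATS-style h-smoothing: actions lie in [0, 1] and each centroid c covers [c - h, c + h], clipped to that range. The arrays below are hand-written hypothetical values for 4 centroids with h = 0.2, not output of `TreeParameters.construct`, and `sample_smoothed_action` is a hypothetical helper, not part of the library.

```python
import numpy as np

# Hypothetical arrays for illustration only: 4 centroids (0.125, 0.375, 0.625,
# 0.875) on an assumed [0, 1] action range, bandwidth h = 0.2, intervals
# clipped to the range (so the outer intervals are narrower).
spaces = np.array([[0.0, 0.325], [0.175, 0.575], [0.425, 0.825], [0.675, 1.0]])
volumes = spaces[:, 1] - spaces[:, 0]  # width of each smoothing interval
probabilities = 1.0 / volumes          # uniform density over each interval


def sample_smoothed_action(centroid_index, rng):
    """Draw an action uniformly from the chosen centroid's interval and return
    it together with the density of that uniform distribution (1 / volume)."""
    low, high = spaces[centroid_index]
    action = rng.uniform(low, high)
    return action, probabilities[centroid_index]


rng = np.random.default_rng(0)
action, density = sample_smoothed_action(1, rng)
print(action, density)  # an action in [0.175, 0.575) with density 1 / 0.4
```

Because each interval is sampled uniformly, its density is simply the reciprocal of its width, which is why `probabilities` can be stored as one over `volumes`.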

construct(discretization_parameter, bandwidth) classmethod

Alternative constructor that computes and initializes the tree parameters from the number of action centroids and the bandwidth.

Parameters:

Name | Type | Description | Default
--- | --- | --- | ---
discretization_parameter | int | the number of action centroids. | required
bandwidth | float | the bucket half-width covered by each action centroid. | required

Returns:

Type | Description
--- | ---
TreeParameters | An initialized TreeParameters dataclass.
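What `construct` has to derive from its two arguments can be sketched in NumPy as follows. This is a hypothetical re-derivation, not the library's code: it assumes actions lie in [0, 1], that the K centroids sit at the midpoints of K equal-width bins, that each interval [c - h, c + h] is clipped to the action range, and that the depth is the number of binary layers needed to reach K leaves. The helper `construct_tree_params` and its returned dict are illustrative names only; `jax.numpy` provides the same `arange`, `clip`, and `stack` calls.

```python
import math

import numpy as np


def construct_tree_params(discretization_parameter: int, bandwidth: float) -> dict:
    """Hypothetical sketch of the tree parameters, assuming an action range of
    [0, 1] and centroids at the midpoints of K equal-width bins."""
    k = discretization_parameter
    # A binary tree needs ceil(log2 K) layers of decisions to reach K leaves.
    depth = math.ceil(math.log2(k))
    centroids = (np.arange(k) + 0.5) / k
    # Each centroid covers [c - h, c + h], clipped to the assumed action range.
    spaces = np.stack(
        [np.clip(centroids - bandwidth, 0.0, 1.0),
         np.clip(centroids + bandwidth, 0.0, 1.0)],
        axis=1,
    )
    volumes = spaces[:, 1] - spaces[:, 0]  # width of each (clipped) interval
    probabilities = 1.0 / volumes          # requires bandwidth > 0
    return {
        "bandwidth": bandwidth,
        "discretization_parameter": k,
        "depth": depth,
        "spaces": spaces,
        "volumes": volumes,
        "probabilities": probabilities,
    }


params = construct_tree_params(discretization_parameter=8, bandwidth=0.1)
print(params["depth"], params["volumes"])
```

With K = 8 and h = 0.1 this gives a depth of 3, interior intervals of width 0.2, and narrower intervals (width 0.1625) for the two outermost centroids, where clipping applies.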

Tree

Bases: hk.Module

__init__(catx_network, tree_params, name=None)

The tree as a JAX Haiku module.

Parameters:

Name | Type | Description | Default
--- | --- | --- | ---
catx_network | Type[CATXHaikuNetwork] | class specifying the neural network architecture. | required
tree_params | TreeParameters | object holding the tree parameters. | required
name | Optional[str] | name of the tree. | None
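How a tree of `depth` layers reaches one of its leaves (the action centroids) can be sketched with a greedy root-to-leaf descent. This is a simplified illustration, not the `Tree` module's forward pass: it assumes a hypothetical layout where layer `l` produces one (left, right) score pair for each of its 2**l nodes, and that leaves are numbered left to right. In the real module those scores would come from the `catx_network`; here they are random or hand-built arrays.

```python
import numpy as np


def greedy_descent(layer_scores):
    """Walk from the root to a leaf, choosing the higher-scoring child at each
    layer. layer_scores[l] has shape (2**l, 2): one (left, right) score pair
    per node in layer l (hypothetical layout)."""
    node = 0
    for scores in layer_scores:
        choice = int(np.argmax(scores[node]))  # 0 = left child, 1 = right child
        node = 2 * node + choice               # index of the child in the next layer
    return node  # leaf index in [0, 2**depth)


depth = 3
rng = np.random.default_rng(0)
random_scores = [rng.normal(size=(2**layer, 2)) for layer in range(depth)]
leaf = greedy_descent(random_scores)

# Scores that always prefer the right child lead to the rightmost leaf.
always_right = [np.tile([0.0, 1.0], (2**layer, 1)) for layer in range(depth)]
rightmost = greedy_descent(always_right)
print(leaf, rightmost)
```

A depth-3 tree has 2**3 = 8 leaves, so each of 8 action centroids is reached by exactly one sequence of three left/right decisions; when the number of centroids is not a power of two, some leaves would go unused under this layout.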