Getting started:
This example uses the Black Friday dataset from OpenML.
The task is to predict how much an individual will purchase (the continuous action) based on that individual's features (the context).
Environment class
For simplicity, we normalize the feature and action spaces.
from typing import Optional
import numpy as np
import tensorflow as tf
from sklearn.datasets import fetch_openml
from catx.type_defs import Actions, Costs, Observations
class BlackFridayEnvironment:
    """Batch-serving environment built on the OpenML dataset with id 41540.

    Features (contexts) and targets are min-max normalized column-wise, then
    served sequentially in batches. Costs are the absolute difference between
    the chosen actions and the targets of the most recently served batch.
    """

    def __init__(self, batch_size: int = 10) -> None:
        """Download, clean, normalize and batch the dataset.

        Args:
            batch_size: number of samples per batch of the TensorFlow dataset.
        """
        self.x, self.y = fetch_openml(data_id=41540, as_frame=False, return_X_y=True)

        # Drop every sample whose feature vector contains at least one NaN.
        rows_with_nan_idx = np.argwhere(np.isnan(self.x))[:, 0]
        self.x = np.delete(self.x, rows_with_nan_idx, axis=0)
        self.y = np.delete(self.y, rows_with_nan_idx, axis=0)

        self.x = self._normalize_data(self.x)
        self.y = self._normalize_data(self.y)

        # Best effort: ask TensorFlow to grow memory on the first GPU on demand.
        # Without a GPU there is nothing to configure, so the call is skipped
        # instead of relying on a swallowed IndexError.
        physical_devices = tf.config.list_physical_devices("GPU")
        if physical_devices:
            try:
                tf.config.experimental.set_memory_growth(physical_devices[0], True)
            except (RuntimeError, ValueError):
                # Documented failure modes of set_memory_growth (runtime already
                # initialized / invalid device): keep TensorFlow's defaults.
                pass

        self.dataset = tf.data.Dataset.from_tensor_slices((self.x, self.y))
        self.dataset = self.dataset.batch(batch_size)
        self.iterator = iter(self.dataset)

    def get_new_observations(self) -> Optional[Observations]:
        """Advance to the next batch and return its features.

        Note that ``self.x`` and ``self.y`` are overwritten with the current
        batch, so ``get_costs`` always scores against the latest batch.

        Returns:
            The features of the next batch, or ``None`` once the dataset
            iterator is exhausted.
        """
        try:
            x, y = self.iterator.get_next()
        except tf.errors.OutOfRangeError:
            return None
        self.x = x.numpy()
        self.y = y.numpy()
        return self.x

    def get_costs(self, actions: Actions) -> Costs:
        """Return the absolute error between ``actions`` and the current targets.

        Args:
            actions: actions to score against ``self.y``; expected to be
                element-wise compatible with it.

        Returns:
            ``|actions - self.y|`` computed element-wise.
        """
        costs = np.abs(actions - self.y)
        return costs

    def _normalize_data(self, data: np.ndarray) -> np.ndarray:
        """Min-max normalize ``data`` along axis 0 into the range [0, 1].

        Columns whose minimum equals their maximum carry no information and
        are mapped to 0 rather than producing NaN from a 0/0 division.

        Args:
            data: array to normalize; statistics are taken over axis 0.

        Returns:
            The normalized array, with the same shape as ``data``.
        """
        lower = np.min(data, axis=0)
        span = np.max(data, axis=0) - lower
        # Replace zero ranges by 1 so constant columns become (data - lower) / 1 == 0.
        safe_span = np.where(span == 0, 1.0, span)
        return (data - lower) / safe_span
Training loop
One of the main advantages of CATX is the flexibility of defining a custom neural network architecture within the tree.
The custom neural network must be a JAX/Haiku network that inherits from CATXHaikuNetwork. In this example, we use a multilayer perceptron (MLP) with dropout, which is enabled only during the learning step.
IMPORTANT: The number of neurons at the output layer should be 2**(depth+1)
# CATX imports
import time
from typing import List
import haiku as hk
import jax
import matplotlib.pyplot as plt
import numpy as np
import optax
from jax import numpy as jnp
from catx.catx import CATX
from catx.network_module import CATXHaikuNetwork
from catx.type_defs import Observations, NetworkExtras, Logits
# Network builder
class MyCATXNetwork(CATXHaikuNetwork):
    """MLP with dropout used as the CATX network for a tree of a given depth.

    The network has two hidden layers of 10 units and an output layer of
    ``2 ** (depth + 1)`` units.
    """

    def __init__(self, depth: int) -> None:
        """Build the MLP for the given depth.

        Args:
            depth: forwarded to ``CATXHaikuNetwork.__init__``; ``self.depth``
                is read afterwards (presumably set by the base class).
        """
        super().__init__(depth)
        output_size = 2 ** (self.depth + 1)
        layer_sizes = [10, 10] + [output_size]
        self.network = hk.nets.MLP(layer_sizes, name=f"mlp_depth_{self.depth}")

    def __call__(
        self,
        obs: Observations,
        network_extras: NetworkExtras,
    ) -> Logits:
        """Run the MLP on ``obs`` with the dropout rate from ``network_extras``.

        Args:
            obs: observations fed to the MLP.
            network_extras: mapping that must contain the key
                ``"dropout_rate"``.

        Returns:
            The MLP output for ``obs``.
        """
        dropout_rate = network_extras["dropout_rate"]
        # A fresh Haiku RNG key is drawn on every call for the dropout mask.
        rng = hk.next_rng_key()
        return self.network(obs, dropout_rate=dropout_rate, rng=rng)
def moving_average(x: List[float], w: int) -> np.ndarray:
    """Return the moving average of ``x`` over full windows of size ``w``.

    Entry ``k`` of the result is the mean of ``x[k:k + w]``, so the result
    has ``len(x) - w + 1`` entries.

    Args:
        x: sequence of values to smooth.
        w: window size; must be positive.

    Returns:
        The averaged values, or an empty array when ``x`` holds fewer than
        ``w`` values (no full window exists).

    Raises:
        ValueError: if ``w`` is not positive.
    """
    if w <= 0:
        raise ValueError(f"window size must be positive, got {w}")
    values = np.asarray(x, dtype=float)
    # With fewer samples than the window, np.convolve in "valid" mode would
    # swap the operands' roles and return meaningless values (or raise on
    # empty input), so report "no full window" explicitly instead.
    if values.size < w:
        return np.empty(0, dtype=float)
    return np.convolve(values, np.ones(w), "valid") / w
def main() -> None:
    """Train CATX on the Black Friday environment and plot the action costs.

    Runs for at most 1000 steps, or until the environment returns no more
    observations. At each step actions are sampled with dropout disabled,
    scored by the environment, and used for a learning step with a dropout
    rate of 0.2. The mean cost of every step is plotted together with its
    moving average, and the elapsed wall-clock time is printed.
    """
    start_time = time.time()
    epsilon = 0.05
    num_steps = 1000
    smoothing_window = 50

    # JAX pseudo-random number generator
    rng_key = jax.random.PRNGKey(42)
    # Only the first half of the split is used (to initialize CATX).
    key, _ = jax.random.split(rng_key)

    # Instantiate the environment
    environment = BlackFridayEnvironment()

    # Instantiate CATX
    catx = CATX(
        catx_network=MyCATXNetwork,
        optimizer=optax.adam(learning_rate=0.01),
        discretization_parameter=8,
        bandwidth=1 / 8,
    )

    # Training loop
    mean_step_costs: List[float] = []
    for i in range(num_steps):
        obs = environment.get_new_observations()
        if obs is None:
            # The environment has no more batches to serve.
            break

        if i == 0:
            # The first batch of observations is needed to initialize the state.
            network_extras = {"dropout_rate": 0.0}
            state = catx.init(
                obs=obs, epsilon=epsilon, key=key, network_extras=network_extras
            )

        # Dropout is disabled while sampling actions...
        state.network_extras["dropout_rate"] = 0.0
        actions, probabilities, state = catx.sample(
            obs=obs, epsilon=epsilon, state=state
        )

        costs = environment.get_costs(actions=actions)

        # ...and enabled for the learning step.
        state.network_extras["dropout_rate"] = 0.2
        state = catx.learn(
            obs=obs,
            actions=actions,
            probabilities=probabilities,
            costs=costs,
            state=state,
        )

        mean_step_costs.append(jnp.mean(costs).item())

    plt.plot(mean_step_costs)
    # The smoothed curve only exists once at least one full window was
    # collected. Each averaged value is drawn at the step where its window
    # ends so that it lines up with the raw per-step costs.
    if len(mean_step_costs) >= smoothing_window:
        smoothed = moving_average(mean_step_costs, smoothing_window)
        window_end_steps = np.arange(smoothing_window - 1, len(mean_step_costs))
        plt.plot(window_end_steps, smoothed)
    plt.title("Action costs")
    plt.show()

    print(f"CATX training took {time.time() - start_time:.1f}s")
# Run the training example only when this file is executed as a script.
if __name__ == "__main__":
    main()