Installation
step 1: CATX installation
The simplest way to install CATX is through PyPI:
pip install catx
We recommend working with the repository inside Docker, Singularity, or a conda environment.
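To confirm that the install in step 1 succeeded, one option is to look up the installed distribution's version from Python. The following is a minimal sketch using only the standard library; the helper name installed_version is hypothetical and not part of CATX.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(dist_name):
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


# "catx" is the distribution name used in the pip command above.
found = installed_version("catx")
if found is None:
    print("catx is not installed in this environment.")
else:
    print(f"catx {found} is installed.")
```

Run it with the same interpreter (or inside the same container or conda environment) that you used for pip install, otherwise the lookup inspects a different environment.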
step 2 [Optional]: JAX GPU installation
The CATX installation in step 1 installs the CPU version of JAX.
To run CATX on a GPU, you must install a GPU-enabled version of JAX.
The JAX installation depends on your CUDA version. Follow the official JAX installation instructions to install JAX with the relevant accelerator support.
TL;DR:
Run pip install --upgrade pip, then run one of the following commands, depending on your machine:
-
No GPU (not needed, as JAX CPU was installed in step 1):
pip install --upgrade "jax[cpu]"
-
GPU with CUDA 11 and cuDNN 8.2 or newer:
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
-
GPU with CUDA >= 11.1 and cuDNN >= 8.0.5:
pip install "jax[cuda11_cudnn805]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
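After installing one of the variants above, it can help to check which backend JAX actually selects. The following is a small sketch, not an official CATX utility; the helper name jax_backend is hypothetical, and the import guard is only there so the script also runs where JAX is missing.

```python
import importlib.util


def jax_backend():
    """Return JAX's default backend name, or None if JAX is not installed."""
    if importlib.util.find_spec("jax") is None:
        return None
    import jax  # imported lazily so the check also works without JAX

    return jax.default_backend()


backend = jax_backend()
if backend is None:
    print("JAX is not installed in this environment.")
else:
    print(f"JAX default backend: {backend}")
```

A CPU-only install is expected to report cpu. If a GPU install still reports cpu, the CUDA or cuDNN version on the machine probably does not match the JAX variant that was installed.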
Installation for development:
-
Clone your fork of this repo:
git clone git@github.com:instadeepai/catx.git
-
Go to directory:
cd catx
-
Add upstream repo:
git remote add upstream https://github.com/instadeepai/catx.git
-
Create your venv or conda environment and install dependencies:
conda:
conda env create -f environment.yaml
conda activate catx
venv (the directory name .venv is an arbitrary choice):
python3 -m venv .venv
source .venv/bin/activate
pip install -e ".[tool,test]"
-
[Optional] Follow step 2 above for JAX to use GPU.