#!/usr/bin/env bash
#
# Install the project's Python dependencies together with TensorFlow,
# JAX (CUDA build) and PyTorch (cudatoolkit 11.3) into the active environment.
#
# Usage: run from the directory containing requirements.txt, with the target
#        conda environment already activated.
#
# NOTE(review): the frameworks come from different sources (PyPI, the JAX
# release index, the `pytorch` conda channel) and may expect different
# CUDA/cuDNN versions; confirm the combination works with the installed driver.

# Stop at the first failed install instead of carrying on with a partially
# built environment.
set -euo pipefail

# Install conda packages first and pip packages afterwards: a later conda
# solve is unaware of pip-installed packages and may overwrite or break them.
conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch

# Invoke pip through the active interpreter so packages land in the same
# environment as `python`, even if another `pip` appears earlier on PATH.
python -m pip install --upgrade pip
python -m pip install -r requirements.txt
python -m pip install tensorflow

# NOTE(review): newer jax releases replaced the plain `cuda` extra with
# versioned ones (e.g. `cuda12`). Unpinned, this may resolve to a CPU-only
# jaxlib; confirm against the jax version you intend to use.
python -m pip install --upgrade "jax[cuda]" \
  -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html