diff --git a/dockers/base-cuda/Dockerfile b/dockers/base-cuda/Dockerfile index 1c241866e5..80f43201cf 100644 --- a/dockers/base-cuda/Dockerfile +++ b/dockers/base-cuda/Dockerfile @@ -16,6 +16,7 @@ ARG CUDA_VERSION=10.2 FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu18.04 +ARG BAGUA_CUDA_VERSION=102 ARG PYTHON_VERSION=3.9 ARG PYTORCH_VERSION=1.8 @@ -117,6 +118,10 @@ RUN \ pip install deepspeed==0.5.7 && \ python -c "import deepspeed; print(deepspeed.__version__)" +RUN \ + # install Bagua + pip install bagua-cuda${BAGUA_CUDA_VERSION}==0.9.0 + RUN \ # Show what we have pip --version && \