From 8439aead663aee1028f520207446fb2e9d0165d3 Mon Sep 17 00:00:00 2001 From: Sean Naren Date: Fri, 23 Apr 2021 12:37:00 +0100 Subject: [PATCH] Update FairScale on CI (#7017) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Try updating CI to latest fairscale * Update availability of imports.py * Remove some of the fairscale custom ci stuff * Update grad scaler within the new process as reference is incorrect for spawn * Remove fairscale from mocks * Install fairscale 0.3.4 into the base container, remove from extra.txt * Update docs/source/conf.py * Fix import issues * Mock fairscale for docs * Fix DeepSpeed and FairScale to specific versions * Swap back to greater than * extras * Revert "extras" This reverts commit 7353479f * ci Co-authored-by: Carlos Mocholí Co-authored-by: jirka --- .github/workflows/ci_test-full.yml | 6 ------ azure-pipelines.yml | 2 +- dockers/base-cuda/Dockerfile | 4 ++++ .../plugins/training_type/sharded_spawn.py | 10 ++++++++++ pytorch_lightning/utilities/imports.py | 4 ++-- requirements/extra.txt | 2 -- 6 files changed, 17 insertions(+), 11 deletions(-) diff --git a/.github/workflows/ci_test-full.yml b/.github/workflows/ci_test-full.yml index 445a3db306..f6d08e0b4f 100644 --- a/.github/workflows/ci_test-full.yml +++ b/.github/workflows/ci_test-full.yml @@ -83,10 +83,6 @@ jobs: req = open(fname).read().replace('>=', '==') open(fname, 'w').write(req) - # remove Fairscale from requirements - fname = 'requirements/extra.txt' - lines = [line for line in open(fname).readlines() if 'fairscale' not in line] - open(fname, 'w').writelines(lines) shell: python # Note: This uses an internal pip API and may not always work @@ -131,8 +127,6 @@ jobs: pip --version # python -m pip install --upgrade --user pip pip install --requirement requirements.txt --find-links https://download.pytorch.org/whl/cpu/torch_stable.html --upgrade - # todo: drop fairscale til it is takem fro mainstream pip, the building take 
ages... - python -c "fname = 'requirements/extra.txt' ; lines = [line for line in open(fname).readlines() if not 'fairscale' in line] ; open(fname, 'w').writelines(lines)" # adjust versions according installed Torch version python ./requirements/adjust_versions.py requirements/extra.txt python ./requirements/adjust_versions.py requirements/examples.txt diff --git a/azure-pipelines.yml b/azure-pipelines.yml index bd2b36be33..15832a6404 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -59,9 +59,9 @@ jobs: #sudo apt-get install -y cmake # python -m pip install "pip==20.1" pip install --requirement requirements.txt - python -c "fname = 'requirements/extra.txt' ; lines = [line for line in open(fname).readlines() if 'fairscale' not in line] ; open(fname, 'w').writelines(lines)" python -c "fname = 'requirements/extra.txt' ; lines = [line for line in open(fname).readlines() if 'horovod' not in line] ; open(fname, 'w').writelines(lines)" pip install --requirement ./requirements/devel.txt --upgrade-strategy only-if-needed + pip install fairscale>=0.3.4 --upgrade-strategy only-if-needed pip list displayName: 'Install dependencies' diff --git a/dockers/base-cuda/Dockerfile b/dockers/base-cuda/Dockerfile index 10446282c0..0d0947b97f 100644 --- a/dockers/base-cuda/Dockerfile +++ b/dockers/base-cuda/Dockerfile @@ -112,6 +112,10 @@ RUN \ # TODO: later commits break CI when cpp extensions are compiling. 
Unset when fixed pip install --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" git+https://github.com/NVIDIA/apex@705cba9 +RUN \ + # install FairScale + pip install fairscale>=0.3.4 + RUN \ # install DeepSpeed # TODO(@SeanNaren): 0.3.15 is broken - skipping to unblock diff --git a/pytorch_lightning/plugins/training_type/sharded_spawn.py b/pytorch_lightning/plugins/training_type/sharded_spawn.py index 2dfe707047..a99d6ea481 100644 --- a/pytorch_lightning/plugins/training_type/sharded_spawn.py +++ b/pytorch_lightning/plugins/training_type/sharded_spawn.py @@ -17,6 +17,7 @@ import torch from torch.optim import Optimizer from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, rank_zero_only @@ -24,6 +25,7 @@ from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, rank_zero_only if _FAIRSCALE_AVAILABLE: from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel from fairscale.optim import OSS + from fairscale.optim.grad_scaler import ShardedGradScaler from pytorch_lightning.overrides.fairscale import LightningShardedDataParallel, unwrap_lightning_module_sharded @@ -76,3 +78,11 @@ class DDPSpawnShardedPlugin(DDPSpawnPlugin): def post_training_step(self): pass + + def new_process(self, process_idx, trainer, mp_queue): + # Ensure that the scaler points to the correct process group + # which is re-initialized in a new process + precision_plugin = trainer.accelerator.precision_plugin + if isinstance(precision_plugin, ShardedNativeMixedPrecisionPlugin): + precision_plugin.scaler = ShardedGradScaler() + super().new_process(process_idx, trainer, mp_queue) diff --git a/pytorch_lightning/utilities/imports.py 
b/pytorch_lightning/utilities/imports.py index 1cd06c19ad..791cef7ff2 100644 --- a/pytorch_lightning/utilities/imports.py +++ b/pytorch_lightning/utilities/imports.py @@ -73,8 +73,8 @@ _TORCH_GREATER_EQUAL_1_9 = _compare_version("torch", operator.ge, "1.9.0") _APEX_AVAILABLE = _module_available("apex.amp") _BOLTS_AVAILABLE = _module_available('pl_bolts') _DEEPSPEED_AVAILABLE = not _IS_WINDOWS and _module_available('deepspeed') -_FAIRSCALE_AVAILABLE = not _IS_WINDOWS and _module_available('fairscale.nn.data_parallel') -_FAIRSCALE_PIPE_AVAILABLE = _TORCH_GREATER_EQUAL_1_6 and _compare_version("fairscale", operator.le, "0.1.3") +_FAIRSCALE_AVAILABLE = _TORCH_GREATER_EQUAL_1_6 and not _IS_WINDOWS and _module_available('fairscale.nn') +_FAIRSCALE_PIPE_AVAILABLE = _FAIRSCALE_AVAILABLE and _compare_version("fairscale", operator.le, "0.1.3") _GROUP_AVAILABLE = not _IS_WINDOWS and _module_available('torch.distributed.group') _HOROVOD_AVAILABLE = _module_available("horovod.torch") _HYDRA_AVAILABLE = _module_available("hydra") diff --git a/requirements/extra.txt b/requirements/extra.txt index a20db98319..e719ee3f30 100644 --- a/requirements/extra.txt +++ b/requirements/extra.txt @@ -7,6 +7,4 @@ torchtext>=0.5 # onnx>=1.7.0 onnxruntime>=1.3.0 hydra-core>=1.0 -# todo: when switch to standard package stream, drop `fairscale` from hard mocked docs libs -fairscale @https://github.com/PyTorchLightning/fairscale/archive/pl_1.2.0.zip jsonargparse[signatures]>=3.9.0