Update FairScale on CI (#7017)

* Try updating CI to latest fairscale

* Update availability checks in imports.py

* Remove some of the fairscale custom ci stuff

* Update grad scaler within the new process as reference is incorrect for spawn

* Remove fairscale from mocks

* Install fairscale 0.3.4 into the base container, remove from extra.txt

* Update docs/source/conf.py

* Fix import issues

* Mock fairscale for docs

* Fix DeepSpeed and FairScale to specific versions

* Swap back to greater than

* extras

* Revert "extras"

This reverts commit 7353479f

* ci

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: jirka <jirka.borovec@seznam.cz>
Sean Naren 2021-04-23 12:37:00 +01:00 committed by GitHub
parent 92af363270
commit 8439aead66
6 changed files with 17 additions and 11 deletions


@@ -83,10 +83,6 @@ jobs:
req = open(fname).read().replace('>=', '==')
open(fname, 'w').write(req)
# remove Fairscale from requirements
fname = 'requirements/extra.txt'
lines = [line for line in open(fname).readlines() if 'fairscale' not in line]
open(fname, 'w').writelines(lines)
shell: python
# Note: This uses an internal pip API and may not always work
@@ -131,8 +127,6 @@ jobs:
pip --version
# python -m pip install --upgrade --user pip
pip install --requirement requirements.txt --find-links https://download.pytorch.org/whl/cpu/torch_stable.html --upgrade
# todo: drop fairscale until it is taken from mainstream pip; building it takes ages...
python -c "fname = 'requirements/extra.txt' ; lines = [line for line in open(fname).readlines() if not 'fairscale' in line] ; open(fname, 'w').writelines(lines)"
# adjust versions according installed Torch version
python ./requirements/adjust_versions.py requirements/extra.txt
python ./requirements/adjust_versions.py requirements/examples.txt
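For readers following the CI change: the deleted step filtered fairscale out of requirements/extra.txt before installation, a workaround from when fairscale had to be built from a GitHub archive. Expanded out of its python -c one-liner form, the removed logic amounts to the following sketch (illustration only, not part of this diff):

# Drop any fairscale entry from the extra requirements before `pip install`.
fname = 'requirements/extra.txt'
with open(fname) as fp:
    lines = [line for line in fp.readlines() if 'fairscale' not in line]
with open(fname, 'w') as fp:
    fp.writelines(lines)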


@@ -59,9 +59,9 @@ jobs:
#sudo apt-get install -y cmake
# python -m pip install "pip==20.1"
pip install --requirement requirements.txt
python -c "fname = 'requirements/extra.txt' ; lines = [line for line in open(fname).readlines() if 'fairscale' not in line] ; open(fname, 'w').writelines(lines)"
python -c "fname = 'requirements/extra.txt' ; lines = [line for line in open(fname).readlines() if 'horovod' not in line] ; open(fname, 'w').writelines(lines)"
pip install --requirement ./requirements/devel.txt --upgrade-strategy only-if-needed
pip install fairscale>=0.3.4 --upgrade-strategy only-if-needed
pip list
displayName: 'Install dependencies'


@@ -112,6 +112,10 @@ RUN \
# TODO: later commits break CI when cpp extensions are compiling. Unset when fixed
pip install --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" git+https://github.com/NVIDIA/apex@705cba9
RUN \
# install FairScale
pip install fairscale>=0.3.4
RUN \
# install DeepSpeed
# TODO(@SeanNaren): 0.3.15 is broken - skipping to unblock
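With fairscale now baked into the base CUDA image alongside Apex and DeepSpeed, a quick sanity check inside the container could look like the sketch below (hedged; it assumes only the >=0.3.4 floor from the Dockerfile line above and fairscale's __version__ attribute):

# Verify the preinstalled fairscale meets the minimum the CI expects.
import fairscale
from packaging.version import Version

assert Version(fairscale.__version__) >= Version("0.3.4"), fairscale.__version__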


@@ -17,6 +17,7 @@ import torch
from torch.optim import Optimizer
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin
from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, rank_zero_only
@@ -24,6 +25,7 @@ from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, rank_zero_only
if _FAIRSCALE_AVAILABLE:
from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel
from fairscale.optim import OSS
from fairscale.optim.grad_scaler import ShardedGradScaler
from pytorch_lightning.overrides.fairscale import LightningShardedDataParallel, unwrap_lightning_module_sharded
@@ -76,3 +78,11 @@ class DDPSpawnShardedPlugin(DDPSpawnPlugin):
def post_training_step(self):
pass
def new_process(self, process_idx, trainer, mp_queue):
# Ensure that the scaler points to the correct process group
# which is re-initialized in a new process
precision_plugin = trainer.accelerator.precision_plugin
if isinstance(precision_plugin, ShardedNativeMixedPrecisionPlugin):
precision_plugin.scaler = ShardedGradScaler()
super().new_process(process_idx, trainer, mp_queue)
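Design note on the hunk above: with the spawn start method, the plugin state (including the precision plugin and its scaler) is pickled into each child process, so a ShardedGradScaler created in the parent can hold a reference that is no longer valid once the child sets up its own distributed state. Re-instantiating the scaler at the start of new_process ensures it is built in the child rather than inherited from the parent.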


@@ -73,8 +73,8 @@ _TORCH_GREATER_EQUAL_1_9 = _compare_version("torch", operator.ge, "1.9.0")
_APEX_AVAILABLE = _module_available("apex.amp")
_BOLTS_AVAILABLE = _module_available('pl_bolts')
_DEEPSPEED_AVAILABLE = not _IS_WINDOWS and _module_available('deepspeed')
_FAIRSCALE_AVAILABLE = not _IS_WINDOWS and _module_available('fairscale.nn.data_parallel')
_FAIRSCALE_PIPE_AVAILABLE = _TORCH_GREATER_EQUAL_1_6 and _compare_version("fairscale", operator.le, "0.1.3")
_FAIRSCALE_AVAILABLE = _TORCH_GREATER_EQUAL_1_6 and not _IS_WINDOWS and _module_available('fairscale.nn')
_FAIRSCALE_PIPE_AVAILABLE = _FAIRSCALE_AVAILABLE and _compare_version("fairscale", operator.le, "0.1.3")
_GROUP_AVAILABLE = not _IS_WINDOWS and _module_available('torch.distributed.group')
_HOROVOD_AVAILABLE = _module_available("horovod.torch")
_HYDRA_AVAILABLE = _module_available("hydra")
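For context on the two helpers used in these flags: _module_available checks whether a dotted module path can be resolved without requiring the package, and _compare_version compares the installed distribution's version using an operator. A simplified, self-contained sketch of that pattern (not the exact Lightning implementations) is:

import importlib.util
import operator
from importlib.metadata import PackageNotFoundError, version  # Python 3.8+
from packaging.version import Version

def module_available(module_path: str) -> bool:
    # True if the dotted path resolves; parent packages are imported as needed.
    try:
        return importlib.util.find_spec(module_path) is not None
    except (ImportError, ValueError):
        return False

def compare_version(package: str, op, target: str) -> bool:
    # e.g. compare_version("fairscale", operator.le, "0.1.3")
    try:
        return op(Version(version(package)), Version(target))
    except PackageNotFoundError:
        return False

# Simplified mirror of the flags above (omitting the Torch version and Windows guards):
_FAIRSCALE_AVAILABLE = module_available('fairscale.nn')
_FAIRSCALE_PIPE_AVAILABLE = _FAIRSCALE_AVAILABLE and compare_version("fairscale", operator.le, "0.1.3")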


@@ -7,6 +7,4 @@ torchtext>=0.5
# onnx>=1.7.0
onnxruntime>=1.3.0
hydra-core>=1.0
# todo: when switch to standard package stream, drop `fairscale` from hard mocked docs libs
fairscale @https://github.com/PyTorchLightning/fairscale/archive/pl_1.2.0.zip
jsonargparse[signatures]>=3.9.0
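One related change is not visible in this excerpt: per the "Mock fairscale for docs" and "Update docs/source/conf.py" bullets, the usual Sphinx approach is to add fairscale to the autodoc mock list so the documentation build no longer needs the package installed from the archive URL removed above. Roughly (a sketch; the exact variable used in Lightning's conf.py may differ):

# docs/source/conf.py (sketch)
autodoc_mock_imports = [
    "fairscale",  # mocked so autodoc can import modules that reference it
    # ... other heavy or optional dependencies ...
]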