Update FairScale on CI (#7017)
* Try updating CI to latest fairscale
* Update availability of imports.py
* Remove some of the fairscale custom ci stuff
* Update grad scaler within the new process as reference is incorrect for spawn
* Remove fairscale from mocks
* Install fairscale 0.3.4 into the base container, remove from extra.txt
* Update docs/source/conf.py
* Fix import issues
* Mock fairscale for docs
* Fix DeepSpeed and FairScale to specific versions
* Swap back to greater than
* extras
* Revert "extras"
This reverts commit 7353479f
* ci
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: jirka <jirka.borovec@seznam.cz>
This commit is contained in:
parent
92af363270
commit
8439aead66
|
@ -83,10 +83,6 @@ jobs:
|
|||
req = open(fname).read().replace('>=', '==')
|
||||
open(fname, 'w').write(req)
|
||||
|
||||
# remove Fairscale from requirements
|
||||
fname = 'requirements/extra.txt'
|
||||
lines = [line for line in open(fname).readlines() if 'fairscale' not in line]
|
||||
open(fname, 'w').writelines(lines)
|
||||
shell: python
|
||||
|
||||
# Note: This uses an internal pip API and may not always work
|
||||
|
@ -131,8 +127,6 @@ jobs:
|
|||
pip --version
|
||||
# python -m pip install --upgrade --user pip
|
||||
pip install --requirement requirements.txt --find-links https://download.pytorch.org/whl/cpu/torch_stable.html --upgrade
|
||||
# todo: drop this workaround once fairscale is taken from mainstream pip; building it from source takes ages
|
||||
python -c "fname = 'requirements/extra.txt' ; lines = [line for line in open(fname).readlines() if 'fairscale' not in line] ; open(fname, 'w').writelines(lines)"
|
||||
# adjust versions according installed Torch version
|
||||
python ./requirements/adjust_versions.py requirements/extra.txt
|
||||
python ./requirements/adjust_versions.py requirements/examples.txt
|
||||
|
|
|
@ -59,9 +59,9 @@ jobs:
|
|||
#sudo apt-get install -y cmake
|
||||
# python -m pip install "pip==20.1"
|
||||
pip install --requirement requirements.txt
|
||||
python -c "fname = 'requirements/extra.txt' ; lines = [line for line in open(fname).readlines() if 'fairscale' not in line] ; open(fname, 'w').writelines(lines)"
|
||||
python -c "fname = 'requirements/extra.txt' ; lines = [line for line in open(fname).readlines() if 'horovod' not in line] ; open(fname, 'w').writelines(lines)"
|
||||
pip install --requirement ./requirements/devel.txt --upgrade-strategy only-if-needed
|
||||
pip install "fairscale>=0.3.4" --upgrade-strategy only-if-needed
|
||||
pip list
|
||||
displayName: 'Install dependencies'
|
||||
|
||||
|
|
|
@ -112,6 +112,10 @@ RUN \
|
|||
# TODO: later commits break CI when cpp extensions are compiling. Unset when fixed
|
||||
pip install --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" git+https://github.com/NVIDIA/apex@705cba9
|
||||
|
||||
RUN \
|
||||
# install FairScale
|
||||
pip install "fairscale>=0.3.4"
|
||||
|
||||
RUN \
|
||||
# install DeepSpeed
|
||||
# TODO(@SeanNaren): 0.3.15 is broken - skipping to unblock
|
||||
|
|
|
@ -17,6 +17,7 @@ import torch
|
|||
from torch.optim import Optimizer
|
||||
|
||||
from pytorch_lightning.core.lightning import LightningModule
|
||||
from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin
|
||||
from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
|
||||
from pytorch_lightning.trainer.states import TrainerState
|
||||
from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, rank_zero_only
|
||||
|
@ -24,6 +25,7 @@ from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, rank_zero_only
|
|||
if _FAIRSCALE_AVAILABLE:
|
||||
from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel
|
||||
from fairscale.optim import OSS
|
||||
from fairscale.optim.grad_scaler import ShardedGradScaler
|
||||
|
||||
from pytorch_lightning.overrides.fairscale import LightningShardedDataParallel, unwrap_lightning_module_sharded
|
||||
|
||||
|
@ -76,3 +78,11 @@ class DDPSpawnShardedPlugin(DDPSpawnPlugin):
|
|||
|
||||
def post_training_step(self):
    # Intentional no-op: overrides the hook inherited from DDPSpawnPlugin
    # (the base class shown in this diff) so nothing extra runs after the
    # training step for the sharded spawn plugin.
    # NOTE(review): the parent's implementation is not visible here —
    # confirm against DDPSpawnPlugin before relying on this description.
    pass
|
||||
|
||||
def new_process(self, process_idx, trainer, mp_queue):
    """Entry point executed inside the newly spawned worker process.

    Before delegating to the parent implementation, re-create the
    ``ShardedGradScaler`` so it references the process group that is
    re-initialized inside this new process; the scaler inherited from
    the parent process would otherwise hold a stale reference under
    the spawn start method.
    """
    plugin = trainer.accelerator.precision_plugin
    if isinstance(plugin, ShardedNativeMixedPrecisionPlugin):
        # Fresh scaler bound to this process's re-initialized group.
        plugin.scaler = ShardedGradScaler()
    super().new_process(process_idx, trainer, mp_queue)
|
||||
|
|
|
@ -73,8 +73,8 @@ _TORCH_GREATER_EQUAL_1_9 = _compare_version("torch", operator.ge, "1.9.0")
|
|||
_APEX_AVAILABLE = _module_available("apex.amp")
|
||||
_BOLTS_AVAILABLE = _module_available('pl_bolts')
|
||||
_DEEPSPEED_AVAILABLE = not _IS_WINDOWS and _module_available('deepspeed')
|
||||
_FAIRSCALE_AVAILABLE = not _IS_WINDOWS and _module_available('fairscale.nn.data_parallel')
|
||||
_FAIRSCALE_PIPE_AVAILABLE = _TORCH_GREATER_EQUAL_1_6 and _compare_version("fairscale", operator.le, "0.1.3")
|
||||
_FAIRSCALE_AVAILABLE = _TORCH_GREATER_EQUAL_1_6 and not _IS_WINDOWS and _module_available('fairscale.nn')
|
||||
_FAIRSCALE_PIPE_AVAILABLE = _FAIRSCALE_AVAILABLE and _compare_version("fairscale", operator.le, "0.1.3")
|
||||
_GROUP_AVAILABLE = not _IS_WINDOWS and _module_available('torch.distributed.group')
|
||||
_HOROVOD_AVAILABLE = _module_available("horovod.torch")
|
||||
_HYDRA_AVAILABLE = _module_available("hydra")
|
||||
|
|
|
@ -7,6 +7,4 @@ torchtext>=0.5
|
|||
# onnx>=1.7.0
|
||||
onnxruntime>=1.3.0
|
||||
hydra-core>=1.0
|
||||
# todo: once we switch to the standard fairscale release stream, drop `fairscale` from the hard-mocked docs libs
|
||||
fairscale @https://github.com/PyTorchLightning/fairscale/archive/pl_1.2.0.zip
|
||||
jsonargparse[signatures]>=3.9.0
|
||||
|
|
Loading…
Reference in New Issue