DeepSpeed Integration (#5954)

* Add initial deepspeed changes

* Address code review

* Move static method outside of function

* Fixes

* Add missing annotation

* Remove seed setting

* Doc changes

* Doc changes, address reviews

* Fix docs

* Try fixing issue by moving to torch adam

* Clean up check

* Changes, better APIs!

* Add wrapper, swap to git install revision

* Add special test

* Add warning

* Address review

* Add better disclaimer

* Turn off ZeRO for testing due to compilation

* Add description on modifying parameters via the plugin

* Doc strings clear

* Small doc fixes

* Fix hash, reduce test

* Added CI change

* Move to azure pipeline

* Fix test name

* Add missing flag

* Remove sudo...

* Try conda instead

* Swap to conda base

* Try suggested install

* Apply suggestions from code review

* Apply suggestions from code review

* Revert "Apply suggestions from code review"

This reverts commit 41cca05a

* Revert "Apply suggestions from code review"

This reverts commit e06ec29e

* Remove setter

* Address most review

* Move out function, remove DeepSpeed from requirements

* Install deepspeed/mpi4py within container

* Use special tests, move to master commit for deepspeed

* Export path

* Force compile to happen first

* Remove!

* Debugging ninja

* Fix error in optimizer step logic

* Attempt to fix symbolic link

* Reverse to aid debugging

* Export path again

* Clean up mess

* var

* Revert "var"

This reverts commit 3450eaca

* Address review, add todo

* Add note about unsupported functionality

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: tchaton <thomas@grid.ai>
Co-authored-by: Jirka Borovec <jirka.borovec@seznam.cz>
This commit is contained in:
Sean Naren 2021-02-17 20:23:42 +00:00 committed by GitHub
parent 6a409c7f84
commit 7189d673f6
16 changed files with 877 additions and 10 deletions

View File

@ -62,6 +62,11 @@ jobs:
pip list
displayName: 'Install dependencies'
- bash: |
# Temporary fix till DeepSpeed release, move this into CUDA image
pip install deepspeed@git+https://github.com/microsoft/DeepSpeed@ec8b1cb
displayName: 'Install DeepSpeed'
- script: |
python tests/collect_env_details.py
displayName: 'Env details'
@ -76,7 +81,9 @@ jobs:
python -m coverage run --source pytorch_lightning -m pytest pytorch_lightning tests -v --durations=50
displayName: 'Testing: standard'
- script: |
- bash: |
# Required for Ninja binary for building extensions, which is installed at this location
export PATH=$PATH:/home/AzDevOps_azpcontainer/.local/bin
sh tests/special_tests.sh
displayName: 'Testing: special'

View File

@ -613,6 +613,8 @@ Lightning currently offers the following methods to leverage model parallelism:
- Sharded Training (partitioning your gradients and optimizer state across multiple GPUs, for reduced memory overhead with **no performance loss**)
- Sequential Model Parallelism with Checkpointing (partition your :class:`nn.Sequential <torch.nn.Sequential>` module across multiple GPUs, leverage checkpointing and microbatching for further memory improvements and device utilization)
.. _sharded:
Sharded Training
^^^^^^^^^^^^^^^^
Lightning integration of optimizer sharded training provided by `FairScale <https://github.com/facebookresearch/fairscale>`_.
@ -678,6 +680,149 @@ Sharded Training can work across all DDP variants by adding the additional ``--p
Internally we re-initialize your optimizers and shard them across your machines and processes. We handle all communication using PyTorch distributed, so no code changes are required.
----------
.. _deep_speed:
DeepSpeed
^^^^^^^^^
.. note::
The DeepSpeed plugin is in beta and the API is subject to change. Please create an `issue <https://github.com/PyTorchLightning/pytorch-lightning/issues>`_ if you run into any issues.
`DeepSpeed <https://github.com/microsoft/DeepSpeed>`_ offers additional CUDA deep learning training optimizations, similar to `FairScale <https://github.com/facebookresearch/fairscale>`_. DeepSpeed provides lower-level training optimizations and efficient optimizers such as `1-bit Adam <https://www.deepspeed.ai/tutorials/onebit-adam/>`_.
Using the plugin, we were able to **train models with 10 billion parameters and above**; this `benchmark <https://github.com/huggingface/transformers/issues/9996>`_ and the DeepSpeed `docs <https://www.deepspeed.ai/tutorials/megatron/>`_ contain a lot of useful information.
We recommend using DeepSpeed in environments where speed and memory optimizations are important (such as training large, billion-parameter models). In addition, we recommend trying :ref:`sharded` before trying DeepSpeed's further optimizations, primarily because FairScale Sharded Training is easier to use in scenarios such as multiple optimizers/schedulers.
To use DeepSpeed, you first need to install DeepSpeed and mpi4py using the command below.
.. code-block:: bash
pip install deepspeed mpi4py
If you run into an issue with the install or later in training, ensure that the CUDA version of the PyTorch you've installed matches your locally installed CUDA (you can check which version has been recognized by running ``nvcc --version``).
Additionally, if you run into any issues installing mpi4py, ensure you have OpenMPI installed using ``sudo apt install libopenmpi-dev`` or ``brew install mpich`` before running ``pip install mpi4py``.
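As a quick sanity check (a minimal sketch, not part of the official troubleshooting steps), you can print the CUDA version PyTorch was built against and compare it with the output of ``nvcc --version``:
.. code-block:: python
import torch
# CUDA version PyTorch was compiled against; this should match ``nvcc --version``
print(torch.version.cuda)
# True only if a compatible GPU and driver are visible to PyTorch
print(torch.cuda.is_available())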
.. note::
Currently ``resume_from_checkpoint`` and manual optimization are not supported.
DeepSpeed supports only a single optimizer and a single scheduler.
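For reference, a minimal sketch of a ``configure_optimizers`` that satisfies this constraint is shown below (``MyModel`` and the SGD/StepLR choices are illustrative, not taken from this change):
.. code-block:: python
from pytorch_lightning import LightningModule
from torch.optim import SGD
from torch.optim.lr_scheduler import StepLR

class MyModel(LightningModule):
    def configure_optimizers(self):
        # DeepSpeed accepts exactly one optimizer and at most one (optional) scheduler
        optimizer = SGD(self.parameters(), lr=1e-3)
        scheduler = StepLR(optimizer, step_size=1)
        return [optimizer], [scheduler]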
ZeRO-Offload
""""""""""""
Below we show an example of running `ZeRO-Offload <https://www.deepspeed.ai/tutorials/zero-offload/>`_. ZeRO-Offload leverages the host CPU to offload optimizer memory/computation, reducing the overall memory consumption.
For an additional speed benefit, DeepSpeed offers an optimized CPU version of Adam to run the offloaded computation, which is faster than the standard PyTorch implementation. ZeRO-Offload is enabled by default.
.. note::
To use ZeRO-Offload, you must use ``precision=16`` or set precision via `the DeepSpeed config. <https://www.deepspeed.ai/docs/config-json/#fp16-training-options>`_.
.. code-block:: python
from pytorch_lightning import Trainer
model = MyModel()
trainer = Trainer(gpus=4, plugins='deepspeed', precision=16)
trainer.fit(model)
This can also be done via the command line using a PyTorch Lightning script:
.. code-block:: bash
python train.py --plugins deepspeed --precision 16 --gpus 4
You can also modify the ZeRO-Offload parameters via the plugin as below.
.. code-block:: python
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DeepSpeedPlugin
model = MyModel()
trainer = Trainer(gpus=4, plugins=DeepSpeedPlugin(allgather_bucket_size=5e8, reduce_bucket_size=5e8), precision=16)
trainer.fit(model)
.. note::
We suggest tuning the ``allgather_bucket_size`` and ``reduce_bucket_size`` parameters to find the optimum values for your model size.
These control the size of the buffers used when reducing gradients and gathering updated parameters. Smaller values use less memory, at the cost of speed.
DeepSpeed allocates a reduce buffer size `multiplied by 4.5x <https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/runtime/zero/stage2.py#L1594-L1607>`_ so take that into consideration when tweaking the parameters.
The plugin sets a reasonable default of ``2e8``, which should work for most low VRAM GPUs (less than ``7GB``), allocating roughly ``3.6GB`` of VRAM as buffer. Higher VRAM GPUs should aim for values around ``5e8``.
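As a rough back-of-the-envelope check of the figures above (an illustrative calculation, assuming fp32 buffer elements at 4 bytes each):
.. code-block:: python
elements = 2e8          # default allgather/reduce bucket size
multiplier = 4.5        # DeepSpeed allocates roughly 4.5x the reduce bucket size
bytes_per_element = 4   # assuming fp32 buffer elements
# prints 3.6, i.e. roughly 3.6GB of VRAM used as buffer, matching the figure above
print(elements * multiplier * bytes_per_element / 1e9)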
Custom DeepSpeed Config
"""""""""""""""""""""""
DeepSpeed allows use of custom DeepSpeed optimizers and schedulers defined within a config file. This allows you to enable optimizers such as `1-bit Adam <https://www.deepspeed.ai/tutorials/onebit-adam/>`_.
.. note::
All plugin default parameters will be ignored when a config object is passed.
All compatible arguments can be seen in the `DeepSpeed docs <https://www.deepspeed.ai/docs/config-json/>`_.
.. code-block:: python
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DeepSpeedPlugin
deepspeed_config = {
"zero_allow_untested_optimizer": True,
"optimizer": {
"type": "OneBitAdam",
"params": {
"lr": 3e-5,
"betas": [0.998, 0.999],
"eps": 1e-5,
"weight_decay": 1e-9,
"cuda_aware": True,
},
},
'scheduler': {
"type": "WarmupLR",
"params": {
"last_batch_iteration": -1,
"warmup_min_lr": 0,
"warmup_max_lr": 3e-5,
"warmup_num_steps": 100,
}
},
"zero_optimization": {
"stage": 2, # Enable Stage 2 ZeRO (Optimizer/Gradient state partitioning)
"cpu_offload": True, # Enable Offloading optimizer state/calculation to the host CPU
"contiguous_gradients": True, # Reduce gradient fragmentation.
"overlap_comm": True, # Overlap reduce/backward operation of gradients for speed.
"allgather_bucket_size": 2e8, # Number of elements to all gather at once.
"reduce_bucket_size": 2e8, # Number of elements we reduce/allreduce at once.
}
}
model = MyModel()
trainer = Trainer(gpus=4, plugins=DeepSpeedPlugin(deepspeed_config), precision=16)
trainer.fit(model)
We also support passing the config as a JSON-formatted file:
.. code-block:: python
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DeepSpeedPlugin
model = MyModel()
trainer = Trainer(gpus=4, plugins=DeepSpeedPlugin("/path/to/deepspeed_config.json"), precision=16)
trainer.fit(model)
You can also use an environment variable via your PyTorch Lightning script:
.. code-block:: bash
PL_DEEPSPEED_CONFIG_PATH=/path/to/deepspeed_config.json python train.py --plugins deepspeed
----------
.. _sequential-parallelism:

View File

@ -284,7 +284,7 @@ class Accelerator(object):
self.training_type_plugin.post_optimizer_step(optimizer, opt_idx, **kwargs)
def run_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs):
optimizer.step(closure=lambda_closure, **kwargs)
self.training_type_plugin.optimizer_step(optimizer, lambda_closure=lambda_closure, **kwargs)
def optimizer_zero_grad(self, current_epoch: int, batch_idx: int, optimizer: Optimizer, opt_idx: int) -> None:
"""Zeros all model parameter's gradients"""
@ -315,9 +315,11 @@ class Accelerator(object):
trainer: the Trainer, these optimizers should be connected to
model: the model to be optimized by the created optimizers
"""
if trainer.testing is True:
if trainer.testing:
return
optimizers, lr_schedulers, optimizer_frequencies = trainer.init_optimizers(self.lightning_module)
optimizers, lr_schedulers, optimizer_frequencies = self.training_type_plugin.init_optimizers(
trainer=trainer, model=self.lightning_module
)
self.optimizers = optimizers
self.lr_schedulers = lr_schedulers
self.optimizer_frequencies = optimizer_frequencies

View File

@ -1,5 +1,6 @@
from pytorch_lightning.plugins.base_plugin import Plugin # noqa: F401
from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.deepspeed_precision import DeepSpeedPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin # noqa: F401
@ -7,6 +8,7 @@ from pytorch_lightning.plugins.precision.tpu_bfloat import TPUHalfPrecisionPlugi
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.ddp2 import DDP2Plugin # noqa: F401
from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.deepspeed import DeepSpeedPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.dp import DataParallelPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.horovod import HorovodPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin # noqa: F401
@ -25,6 +27,8 @@ __all__ = [
"DDP2Plugin",
"DDPPlugin",
"DDPSpawnPlugin",
"DeepSpeedPlugin",
"DeepSpeedPrecisionPlugin",
"HorovodPlugin",
"NativeMixedPrecisionPlugin",
"PrecisionPlugin",

View File

@ -1,4 +1,5 @@
from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.deepspeed_precision import DeepSpeedPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin # noqa: F401

View File

@ -0,0 +1,61 @@
from typing import Callable, Union
import torch
from torch.optim import Optimizer
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.warnings import WarningCache
warning_cache = WarningCache()
class DeepSpeedPrecisionPlugin(PrecisionPlugin):
def __init__(self, precision):
super().__init__()
self.precision = precision
def pre_optimizer_step(
self, pl_module: LightningModule, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs
) -> bool:
deepspeed_engine = pl_module.trainer.model
# DeepSpeed does not support closures.
lambda_closure()
if not pl_module.automatic_optimization:
pl_module.trainer.call_hook("on_after_backward")
deepspeed_engine.step()
return False
def backward(
self,
lightning_module: LightningModule,
closure_loss: torch.Tensor,
optimizer: torch.optim.Optimizer,
opt_idx: int,
should_accumulate: bool,
*args,
**kwargs,
):
if is_overridden('backward', lightning_module):
warning_cache.warn(
"Overridden backward hook in the LightningModule will be ignored since DeepSpeed handles"
"backward logic outside of the LightningModule"
)
# todo: hack around for deepspeed engine to call backward
deepspeed_engine = lightning_module.trainer.model
deepspeed_engine.backward(closure_loss, **kwargs)
# once backward has been applied, release graph
closure_loss = closure_loss.detach()
return closure_loss
def clip_gradients(self, optimizer: Optimizer, clip_val: Union[int, float], norm_type: float = float(2.0)):
"""
DeepSpeed handles clipping gradients via the training type plugin.
"""
pass

View File

@ -1,6 +1,7 @@
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
from pytorch_lightning.plugins.training_type.ddp2 import DDP2Plugin
from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
from pytorch_lightning.plugins.training_type.deepspeed import DeepSpeedPlugin
from pytorch_lightning.plugins.training_type.dp import DataParallelPlugin
from pytorch_lightning.plugins.training_type.horovod import HorovodPlugin
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin

View File

@ -0,0 +1,323 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import os
from pathlib import Path
from types import SimpleNamespace
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
from torch.nn.parallel import DistributedDataParallel
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_only
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE
if _DEEPSPEED_AVAILABLE:
import deepspeed
class LightningDeepSpeedModule(_LightningModuleWrapperBase):
def __init__(self, pl_module: LightningModule, precision: int):
super().__init__(pl_module)
self.precision = precision
def forward(self, *inputs, **kwargs):
if self.precision == 16:
inputs = self._move_float_tensors_to_half(inputs)
return super().forward(*inputs, **kwargs)
@staticmethod
def batch_to(data):
return data.half()
def _move_float_tensors_to_half(self, batch: Any):
batch = apply_to_collection(batch, (torch.FloatTensor, torch.cuda.FloatTensor), function=self.batch_to)
return batch
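To illustrate what ``_move_float_tensors_to_half`` relies on, below is a standalone sketch (not part of this change) of ``apply_to_collection`` converting only the float tensors in a batch:
# Standalone illustration: apply_to_collection applies the function only to
# items matching the given tensor types, leaving everything else untouched.
import torch
from pytorch_lightning.utilities.apply_func import apply_to_collection

batch = {"x": torch.randn(2, 3), "y": torch.tensor([1, 2])}
half_batch = apply_to_collection(batch, torch.FloatTensor, lambda t: t.half())
# half_batch["x"].dtype is torch.float16, while half_batch["y"] stays torch.int64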
class DeepSpeedPlugin(DDPPlugin):
distributed_backend = "deepspeed"
DEEPSPEED_ENV_VAR = "PL_DEEPSPEED_CONFIG_PATH"
def __init__(
self,
zero_optimization: bool = True,
stage: int = 2,
cpu_offload: bool = True,
contiguous_gradients: bool = True,
overlap_comm: bool = True,
allgather_partitions: bool = True,
reduce_scatter: bool = True,
allgather_bucket_size: int = 2e8,
reduce_bucket_size: int = 2e8,
zero_allow_untested_optimizer: bool = True,
config: Optional[Union[Path, str, dict]] = None,
logging_level: int = logging.WARN,
num_nodes: int = 1,
parallel_devices: Optional[List[torch.device]] = None,
cluster_environment: Optional[ClusterEnvironment] = None,
) -> None:
"""
Provides capabilities to run training using the DeepSpeed library,
with training optimizations for large billion parameter models.
`For more information: https://www.deepspeed.ai/`.
.. warning:: ``DeepSpeedPlugin`` is in beta and subject to change.
Defaults have been set to enable ZeRO-Offload, and some values have been taken from the link below.
These defaults are generally applicable, but may require tuning for optimum performance based on your model size.
`For more information: https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training`.
Arguments:
zero_optimization: Enable ZeRO optimization. This is only compatible with precision=16. (default: True)
stage: Different stages of the ZeRO Optimizer. 0 is disabled,
1 is optimizer state partitioning, 2 is optimizer+gradient state partitioning (default: 2)
cpu_offload: Enable offloading optimizer memory and computation to CPU (default: True)
contiguous_gradients: Copies gradients to a continuous buffer as they are produced.
Avoids memory fragmentation during backwards. Useful when training large models. (default: True)
overlap_comm: Overlap the reduction (synchronization) of gradients with the backwards computation.
This is a speed optimization when training across multiple GPUs/machines. (default: True)
allgather_partitions: All gather updated parameters at the end of training step,
instead of using a series of broadcast collectives (default: True)
reduce_scatter: Use reduce/scatter instead of allreduce to average gradients (default: True)
allgather_bucket_size: Number of elements to allgather at once.
Used to limit the memory required for larger model sizes, with a tradeoff with speed. (default: 2e8)
reduce_bucket_size: Number of elements to reduce at once.
Used to limit the memory required for larger model sizes, with a tradeoff with speed (default: 2e8)
zero_allow_untested_optimizer: Allow untested optimizers to be used with ZeRO. Currently only Adam is a
DeepSpeed supported optimizer when using ZeRO (default: True)
config: Pass in a deepspeed formatted config dict,
or path to a deepspeed config: https://www.deepspeed.ai/docs/config-json.
All defaults will be ignored if a config is passed in. (Default: ``None``)
logging_level: Set logging level for deepspeed. (Default: ``logging.WARN``)
"""
if not _DEEPSPEED_AVAILABLE:
raise MisconfigurationException(
"To use the DeepSpeed plugin, you must have DeepSpeed installed."
" pip install deepspeed mpi4py"
)
super().__init__(
parallel_devices=parallel_devices, num_nodes=num_nodes, cluster_environment=cluster_environment
)
self.config = self._load_config(config)
if self.config is None:
# User has not overridden config, set defaults
self.config = self._create_default_config(
zero_optimization,
zero_allow_untested_optimizer,
stage=stage,
cpu_offload=cpu_offload,
contiguous_gradients=contiguous_gradients,
overlap_comm=overlap_comm,
allgather_partitions=allgather_partitions,
reduce_scatter=reduce_scatter,
allgather_bucket_size=allgather_bucket_size,
reduce_bucket_size=reduce_bucket_size
)
self._config_initialized = False
deepspeed.utils.logging.logger.setLevel(logging_level)
def _load_config(self, config):
if config is None and self.DEEPSPEED_ENV_VAR in os.environ:
rank_zero_info(f"Loading DeepSpeed config from set {self.DEEPSPEED_ENV_VAR} environment variable")
config = os.environ[self.DEEPSPEED_ENV_VAR]
if isinstance(config, str) or isinstance(config, Path):
if os.path.exists(config):
with open(config) as f:
config = json.load(f)
else:
raise MisconfigurationException(
f"You passed in a path to a DeepSpeed config but the path does not exist: {config}"
)
return config
def pre_dispatch(self):
self.set_world_ranks()
self.init_ddp_connection(self.global_rank, self.world_size)
self.init_deepspeed()
# set warning rank
rank_zero_only.rank = self.global_rank
# set the ranks and devices
self.dist.rank = self.global_rank
self.dist.device = self.root_device
self.barrier()
def init_deepspeed(self):
if not self._config_initialized:
self._format_config()
self._config_initialized = True
precision = self.lightning_module.trainer.accelerator_backend.precision
model = LightningDeepSpeedModule(pl_module=self.model, precision=precision)
if self.lightning_module.trainer.training:
self._initialize_deepspeed_train(model)
else:
self._initialize_deepspeed_inference(model)
def _init_scheduler_optimizer(self):
optimizers, schedulers, optimizer_frequencies = self.lightning_module.trainer.init_optimizers(
self.lightning_module
)
if (len(optimizers) != 1) or len(schedulers) > 1:
raise MisconfigurationException(
"DeepSpeed currently only supports single optimizer, single optional scheduler."
)
scheduler = schedulers[0]['scheduler'] if len(schedulers) == 1 else None
optimizer = optimizers[0]
return optimizer, scheduler, optimizer_frequencies
def _initialize_deepspeed_train(self, model):
optimizer, lightning_scheduler, optimizer_frequencies = None, None, None
if "optimizer" not in self.config:
rank_zero_info(
"You have not specified an optimizer or scheduler within the DeepSpeed config."
"Using `configure_optimizers` to define optimizer and scheduler."
)
optimizer, lightning_scheduler, optimizer_frequencies = self._init_scheduler_optimizer()
model_parameters = filter(lambda p: p.requires_grad, self.model.parameters())
model, optimizer, _, lr_scheduler = deepspeed.initialize(
args=SimpleNamespace(local_rank=self.local_rank),
model=model,
model_parameters=model_parameters,
optimizer=optimizer,
lr_scheduler=lightning_scheduler,
config_params=self.config,
)
# set optimizer for save/load, but deepspeed manages the specific optimizer logic
trainer = self.lightning_module.trainer
trainer.optimizers = [optimizer]
self.model = model
def _initialize_deepspeed_inference(self, model):
# move the model to the correct device
self.model_to_device()
self.pre_configure_ddp()
self._model = DistributedDataParallel(
model,
device_ids=self.determine_ddp_device_ids(),
**self._ddp_kwargs,
)
def configure_scheduler(self, lr_scheduler):
# this duplicates the defaults from init_optimizers
scheduler = {
'scheduler': lr_scheduler,
'name': None, # no custom name
'interval': 'epoch', # after epoch is over
'frequency': 1, # every epoch/batch
'reduce_on_plateau': False, # most often not ReduceLROnPlateau scheduler
'monitor': None, # value to monitor for ReduceLROnPlateau
'strict': True, # enforce that the monitor exists for ReduceLROnPlateau
}
return [scheduler]
@property
def lightning_module(self):
# the model may not be wrapped with DeepSpeedEngine & LightningDeepSpeedModule if calling this too early
module = getattr(self.model, "module", self.model)
return module.module if isinstance(module, LightningDeepSpeedModule) else module
@property
def distributed_sampler_kwargs(self):
distributed_sampler_kwargs = dict(num_replicas=self.world_size, rank=self.global_rank)
return distributed_sampler_kwargs
def init_optimizers(self, trainer: "Trainer", model: LightningModule) -> Tuple[List, List, List]:
# Skip initializing optimizers here as DeepSpeed handles optimizers via config.
# User may have specified config options instead in configure_optimizers, but this is handled
# via `_initialize_deepspeed_train`
return [], [], [] # empty optimizers, schedulers and frequencies
def optimizer_step(self, optimizer: torch.optim.Optimizer, lambda_closure: Callable, **kwargs):
# note: We rely on the deepspeed engine to carry out the step rather than the optimizer.
# internally, the engine has a reference to the optimizer already.
self.model.step(**kwargs)
def _format_config(self):
if self.config is None:
raise MisconfigurationException(
"To use DeepSpeed you must pass in a DeepSpeed config dict, or a path to a JSON config."
" See: https://pytorch-lightning.readthedocs.io/en/latest/advanced/multi_gpu.html#deepspeed"
)
self._format_batch_size_and_grad_accum_config()
self._format_precision_config()
def _format_batch_size_and_grad_accum_config(self):
if "gradient_accumulation_steps" in self.config:
raise MisconfigurationException(
"Within the DeepSpeed config, do not set gradient_accumulation_steps"
" as this will be set via accumulate_grad_batches=x argument passed via the Lightning Trainer."
)
if "train_micro_batch_size_per_gpu" not in self.config:
# train_micro_batch_size_per_gpu is used for throughput logging purposes
# by default we use the batch size of the loader which may be incorrect if a batch sampler is passed
batch_size = self.lightning_module.train_dataloader().batch_size
self.config["train_micro_batch_size_per_gpu"] = batch_size
self.config["gradient_accumulation_steps"] = self.lightning_module.trainer.accumulate_grad_batches
if "gradient_clipping" not in self.config:
self.config["gradient_clipping"] = self.lightning_module.trainer.gradient_clip_val
def _format_precision_config(self):
amp_type = self.lightning_module.trainer.accelerator_connector.amp_type
amp_level = self.lightning_module.trainer.accelerator_connector.amp_level
precision = self.lightning_module.trainer.accelerator_connector.precision
if precision == 16:
if "amp" not in self.config and amp_type == AMPType.NATIVE:
self.config["fp16"] = {"enabled": True}
elif "apex" not in self.config and amp_type == AMPType.APEX:
self.config["amp"] = {
"enabled": True,
"opt_level": amp_level,
}
if "zero_optimization" in self.config and not ("amp" in self.config or "fp16" in self.config):
raise MisconfigurationException("To use DeepSpeed ZeRO Optimization, you must set precision=16.")
def _create_default_config(
self, zero_optimization: bool, zero_allow_untested_optimizer: bool, **zero_kwargs
) -> Dict:
if zero_optimization:
return {"zero_allow_untested_optimizer": zero_allow_untested_optimizer, "zero_optimization": zero_kwargs}
return {}
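For reference, with the constructor defaults shown above, ``_create_default_config`` yields a config roughly equivalent to the following (a sketch assembled from the ``__init__`` defaults, not copied verbatim from this change):
# Approximate ZeRO config produced from the DeepSpeedPlugin defaults (illustrative).
default_config = {
    "zero_allow_untested_optimizer": True,
    "zero_optimization": {
        "stage": 2,
        "cpu_offload": True,
        "contiguous_gradients": True,
        "overlap_comm": True,
        "allgather_partitions": True,
        "reduce_scatter": True,
        "allgather_bucket_size": 2e8,
        "reduce_bucket_size": 2e8,
    },
}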

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from typing import Any, Iterable, Optional, TYPE_CHECKING, Union
from typing import Any, Callable, Iterable, Optional, TYPE_CHECKING, Union
import torch
from torch.nn import Module
@ -152,3 +152,9 @@ class TrainingTypePlugin(Plugin, ABC):
dataloader: iterable. Ideally of type: :class:`torch.utils.data.DataLoader`
"""
return dataloader
def init_optimizers(self, trainer: "Trainer", model: LightningModule):
return trainer.init_optimizers(model)
def optimizer_step(self, optimizer: torch.optim.Optimizer, lambda_closure: Callable, **kwargs):
optimizer.step(closure=lambda_closure, **kwargs)

View File

@ -30,8 +30,11 @@ from pytorch_lightning.plugins import (
DDPShardedPlugin,
DDPSpawnPlugin,
DDPSpawnShardedPlugin,
DeepSpeedPlugin,
DeepSpeedPrecisionPlugin,
HorovodPlugin,
NativeMixedPrecisionPlugin,
Plugin,
PrecisionPlugin,
ShardedNativeMixedPrecisionPlugin,
SingleDevicePlugin,
@ -144,7 +147,7 @@ class AcceleratorConnector(object):
self.replace_sampler_ddp = replace_sampler_ddp
def handle_given_plugins(self, plugins: Optional[Sequence]):
def handle_given_plugins(self, plugins: Optional[Union[Plugin, Sequence]]):
plugins = plugins if plugins is not None else []
if isinstance(plugins, str):
@ -243,7 +246,7 @@ class AcceleratorConnector(object):
def use_ddp(self) -> bool:
return self._distrib_type in (
DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP_SHARDED,
DistributedType.DDP_SHARDED_SPAWN
DistributedType.DDP_SHARDED_SPAWN, DistributedType.DEEPSPEED
)
@property
@ -254,6 +257,10 @@ class AcceleratorConnector(object):
def use_horovod(self) -> bool:
return self._distrib_type == DistributedType.HOROVOD
@property
def use_deepspeed(self) -> bool:
return self._distrib_type == DistributedType.DEEPSPEED
@property
def is_distributed(self) -> bool:
is_distributed = self.use_ddp or self.use_ddp2 or self.use_horovod
@ -290,15 +297,19 @@ class AcceleratorConnector(object):
return te_flags_passed
def select_precision_plugin(self) -> PrecisionPlugin:
# set precision type
self.amp_type = AMPType.from_str(self.amp_type)
if self._distrib_type == DistributedType.DEEPSPEED or isinstance(self._training_type_plugin, DeepSpeedPlugin):
return DeepSpeedPrecisionPlugin(self.precision)
if self.precision == 32:
self.amp_type = None
return PrecisionPlugin()
elif self.precision == 16:
if self.on_tpu:
return TPUHalfPrecisionPlugin()
self.amp_type = AMPType(self.amp_type)
if self.amp_type == AMPType.NATIVE:
if self.on_cpu:
raise MisconfigurationException(
@ -338,6 +349,12 @@ class AcceleratorConnector(object):
def select_training_type_plugin(self) -> TrainingTypePlugin:
if self.use_ddp2:
plugin = DDP2Plugin(parallel_devices=self.parallel_devices, cluster_environment=self.cluster_environment)
elif self.use_ddp and self.use_deepspeed:
plugin = DeepSpeedPlugin(
num_nodes=self.num_nodes,
cluster_environment=self.select_cluster_environment(),
parallel_devices=self.parallel_devices
)
elif self.use_ddp:
use_slurm_ddp = self.use_ddp and self.is_slurm_managing_tasks
use_torchelastic_ddp = self.use_ddp and self.is_using_torchelastic

View File

@ -29,6 +29,7 @@ from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.memory import ModelSummary
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.plugins import Plugin
from pytorch_lightning.profiler import BaseProfiler
from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin
from pytorch_lightning.trainer.configuration_validator import ConfigValidator
@ -130,7 +131,7 @@ class Trainer(
terminate_on_nan: bool = False,
auto_scale_batch_size: Union[str, bool] = False,
prepare_data_per_node: bool = True,
plugins: Optional[Union[str, list]] = None,
plugins: Optional[Union[Plugin, str, list]] = None,
amp_backend: str = 'native',
amp_level: str = 'O2',
distributed_backend: Optional[str] = None,

View File

@ -26,6 +26,7 @@ from pytorch_lightning.utilities.enums import AMPType, DeviceType, DistributedTy
from pytorch_lightning.utilities.imports import ( # noqa: F401
_APEX_AVAILABLE,
_BOLTS_AVAILABLE,
_DEEPSPEED_AVAILABLE,
_FAIRSCALE_AVAILABLE,
_FAIRSCALE_PIPE_AVAILABLE,
_GROUP_AVAILABLE,

View File

@ -62,6 +62,7 @@ class DistributedType(LightningEnum):
DDP = 'ddp'
DDP2 = 'ddp2'
DDP_SPAWN = 'ddp_spawn'
DEEPSPEED = 'deepspeed'
HOROVOD = 'horovod'
DDP_SHARDED = 'ddp_sharded'
DDP_SHARDED_SPAWN = 'ddp_sharded_spawn'

View File

@ -55,6 +55,7 @@ _TORCH_GREATER_EQUAL_1_7 = _compare_version("torch", operator.ge, "1.7.0")
_TORCH_QUANTIZE_AVAILABLE = _module_available('torch.ops.quantized')
_APEX_AVAILABLE = _module_available("apex.amp")
_BOLTS_AVAILABLE = _module_available('pl_bolts')
_DEEPSPEED_AVAILABLE = not _IS_WINDOWS and _module_available('deepspeed')
_FAIRSCALE_AVAILABLE = not _IS_WINDOWS and _module_available('fairscale.nn.data_parallel')
_FAIRSCALE_PIPE_AVAILABLE = _TORCH_GREATER_EQUAL_1_6 and _compare_version("fairscale", operator.le, "0.1.3")
_GROUP_AVAILABLE = not _IS_WINDOWS and _module_available('torch.distributed.group')

View File

@ -0,0 +1,292 @@
import json
import os
import pytest
import torch
from torch import Tensor
from torch.optim import Optimizer
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DeepSpeedPlugin, DeepSpeedPrecisionPlugin
from pytorch_lightning.utilities import _APEX_AVAILABLE, _DEEPSPEED_AVAILABLE, _NATIVE_AMP_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers.boring_model import BoringModel
@pytest.fixture
def deepspeed_config():
return {
"optimizer": {
"type": "SGD",
"params": {
"lr": 3e-5,
},
},
'scheduler': {
"type": "WarmupLR",
"params": {
"last_batch_iteration": -1,
"warmup_min_lr": 0,
"warmup_max_lr": 3e-5,
"warmup_num_steps": 100,
}
}
}
@pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.")
def test_deepspeed_plugin_string(tmpdir):
"""
Test to ensure that the plugin can be passed via string, and parallel devices is correctly set.
"""
trainer = Trainer(
fast_dev_run=True,
default_root_dir=tmpdir,
plugins='deepspeed',
)
assert isinstance(trainer.accelerator_backend.training_type_plugin, DeepSpeedPlugin)
assert trainer.accelerator_backend.training_type_plugin.parallel_devices == [torch.device('cpu')]
@pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.")
def test_deepspeed_plugin(tmpdir):
"""
Test to ensure that the plugin can be passed directly, and parallel devices is correctly set.
"""
trainer = Trainer(
fast_dev_run=True,
default_root_dir=tmpdir,
plugins=[DeepSpeedPlugin()],
)
assert isinstance(trainer.accelerator_backend.training_type_plugin, DeepSpeedPlugin)
assert trainer.accelerator_backend.training_type_plugin.parallel_devices == [torch.device('cpu')]
@pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.")
def test_deepspeed_plugin_env(tmpdir, monkeypatch, deepspeed_config):
"""
Test to ensure that the plugin can be passed via a string with an environment variable.
"""
config_path = os.path.join(tmpdir, 'temp.json')
with open(config_path, 'w') as f:
f.write(json.dumps(deepspeed_config))
monkeypatch.setenv("PL_DEEPSPEED_CONFIG_PATH", config_path)
trainer = Trainer(
fast_dev_run=True,
default_root_dir=tmpdir,
plugins='deepspeed',
)
plugin = trainer.accelerator_backend.training_type_plugin
assert isinstance(plugin, DeepSpeedPlugin)
assert plugin.parallel_devices == [torch.device('cpu')]
assert plugin.config == deepspeed_config
@pytest.mark.parametrize(
"amp_backend", [
pytest.param("native", marks=pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="Requires native AMP")),
pytest.param("apex", marks=pytest.mark.skipif(not _APEX_AVAILABLE, reason="Requires Apex")),
]
)
@pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.")
@pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="Requires native AMP")
def test_deepspeed_precision_choice(amp_backend, tmpdir):
"""
Test to ensure precision plugin is also correctly chosen.
DeepSpeed handles precision via Custom DeepSpeedPrecisionPlugin
"""
trainer = Trainer(
fast_dev_run=True, default_root_dir=tmpdir, plugins='deepspeed', amp_backend=amp_backend, precision=16
)
assert isinstance(trainer.accelerator_backend.training_type_plugin, DeepSpeedPlugin)
assert isinstance(trainer.accelerator_backend.precision_plugin, DeepSpeedPrecisionPlugin)
assert trainer.accelerator_backend.precision_plugin.precision == 16
@pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.")
def test_deepspeed_with_invalid_config_path(tmpdir):
"""
Test to ensure if we pass an invalid config path we throw an exception.
"""
with pytest.raises(
MisconfigurationException, match="You passed in a path to a DeepSpeed config but the path does not exist"
):
DeepSpeedPlugin(config='invalid_path.json')
@pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.")
def test_deepspeed_with_env_path(tmpdir, monkeypatch, deepspeed_config):
"""
Test to ensure if we pass an env variable, we load the config from the path.
"""
config_path = os.path.join(tmpdir, 'temp.json')
with open(config_path, 'w') as f:
f.write(json.dumps(deepspeed_config))
monkeypatch.setenv("PL_DEEPSPEED_CONFIG_PATH", config_path)
plugin = DeepSpeedPlugin()
assert plugin.config == deepspeed_config
@pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.")
def test_deepspeed_defaults(tmpdir):
"""
Ensure that defaults are correctly set as a config for DeepSpeed if no arguments are passed.
"""
plugin = DeepSpeedPlugin()
assert plugin.config is not None
assert isinstance(plugin.config["zero_optimization"], dict)
@pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.")
def test_invalid_deepspeed_defaults_no_precision(tmpdir):
"""
Test to ensure that using defaults, if precision is not set to 16, we throw an exception.
"""
model = BoringModel()
trainer = Trainer(
fast_dev_run=True,
default_root_dir=tmpdir,
plugins='deepspeed',
)
with pytest.raises(
MisconfigurationException, match='To use DeepSpeed ZeRO Optimization, you must set precision=16.'
):
trainer.fit(model)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine")
@pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.")
@pytest.mark.skipif(
not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest"
)
def test_warn_deepspeed_override_backward(tmpdir):
"""
Test to ensure that if the backward hook in the LightningModule is overridden, we throw a warning.
"""
class TestModel(BoringModel):
def backward(self, loss: Tensor, optimizer: Optimizer, optimizer_idx: int, *args, **kwargs) -> None:
return loss.backward()
model = TestModel()
trainer = Trainer(
fast_dev_run=True,
default_root_dir=tmpdir,
plugins=DeepSpeedPlugin(zero_optimization=False),
gpus=1,
)
with pytest.warns(UserWarning, match='Overridden backward hook in the LightningModule will be ignored'):
trainer.fit(model)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine")
@pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.")
@pytest.mark.skipif(
not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest"
)
def test_deepspeed_run_configure_optimizers(tmpdir):
"""
Test end to end that deepspeed works with defaults (without ZeRO as that requires compilation),
whilst using configure_optimizers for optimizers and schedulers.
"""
class TestModel(BoringModel):
def on_train_start(self) -> None:
assert isinstance(self.trainer.optimizers[0], torch.optim.SGD)
assert self.trainer.lr_schedulers == [] # DeepSpeed manages LR scheduler internally
# Ensure DeepSpeed engine has initialized with our optimizer/lr_scheduler
assert isinstance(self.trainer.model.lr_scheduler, torch.optim.lr_scheduler.StepLR)
model = TestModel()
trainer = Trainer(
plugins=DeepSpeedPlugin(zero_optimization=False),
default_root_dir=tmpdir,
gpus=1,
fast_dev_run=True,
)
trainer.fit(model)
_assert_save_model_is_equal(model, tmpdir, trainer)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine")
@pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.")
@pytest.mark.skipif(
not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest"
)
def test_deepspeed_config(tmpdir, deepspeed_config):
"""
Test to ensure deepspeed works correctly when passed a DeepSpeed config object including optimizers/schedulers
and saves the model weights to load correctly.
"""
class TestModel(BoringModel):
def on_train_start(self) -> None:
import deepspeed
assert isinstance(self.trainer.optimizers[0], torch.optim.SGD)
assert self.trainer.lr_schedulers == [] # DeepSpeed manages LR scheduler internally
assert isinstance(self.trainer.model.optimizer, torch.optim.SGD)
assert isinstance(self.trainer.model.lr_scheduler, deepspeed.runtime.lr_schedules.WarmupLR)
model = TestModel()
trainer = Trainer(
plugins=[DeepSpeedPlugin(config=deepspeed_config)],
default_root_dir=tmpdir,
gpus=1,
fast_dev_run=True,
)
trainer.fit(model)
trainer.test(model)
_assert_save_model_is_equal(model, tmpdir, trainer)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine")
@pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(
not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest"
)
def test_deepspeed_multigpu(tmpdir, deepspeed_config):
"""
Test to ensure that DeepSpeed with multiple GPUs works, without ZeRO Optimization as this requires compilation.
"""
model = BoringModel()
trainer = Trainer(
plugins=[DeepSpeedPlugin(zero_optimization=False)],
default_root_dir=tmpdir,
gpus=2,
fast_dev_run=True,
precision=16,
)
trainer.fit(model)
trainer.test(model)
_assert_save_model_is_equal(model, tmpdir, trainer)
def _assert_save_model_is_equal(model, tmpdir, trainer):
checkpoint_path = os.path.join(tmpdir, 'model.pt')
trainer.save_checkpoint(checkpoint_path)
# carry out the check only on rank 0
if trainer.global_rank == 0:
saved_model = BoringModel.load_from_checkpoint(checkpoint_path)
saved_model = saved_model.float()
model = model.float().cpu()
# Assert model parameters are identical after loading
for orig_param, trained_model_param in zip(model.parameters(), saved_model.parameters()):
assert torch.equal(orig_param, trained_model_param)

View File

@ -17,6 +17,10 @@ export PL_RUNNING_SPECIAL_TESTS=1
DEFAULTS="-m coverage run --source pytorch_lightning -a -m pytest --verbose --capture=no"
python ${DEFAULTS} tests/trainer/optimization/test_manual_optimization.py::test_step_with_optimizer_closure_with_different_frequencies_ddp
python ${DEFAULTS} tests/models/test_sync_batchnorm.py::test_sync_batchnorm_ddp
python ${DEFAULTS} tests/plugins/test_deepspeed_plugin.py::test_warn_deepspeed_override_backward
python ${DEFAULTS} tests/plugins/test_deepspeed_plugin.py::test_deepspeed_run_configure_optimizers
python ${DEFAULTS} tests/plugins/test_deepspeed_plugin.py::test_deepspeed_config
python ${DEFAULTS} tests/plugins/test_deepspeed_plugin.py::test_deepspeed_multigpu
python ${DEFAULTS} tests/plugins/test_rpc_plugin.py::test_rpc_function_calls_ddp
python ${DEFAULTS} tests/plugins/test_rpc_sequential_plugin.py::test_rpc_sequential_plugin_manual
python ${DEFAULTS} tests/plugins/test_rpc_sequential_plugin.py::test_rpc_sequential_plugin_manual_amp