Add auto wrapping for `DDPFullyShardedNativeStrategy` (#14252)

Rohit Gupta 2022-08-26 14:31:48 +05:30 committed by GitHub
parent 70deac2cd4
commit 6d00f31f0c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 166 additions and 33 deletions


@@ -212,14 +212,31 @@ PyTorch Fully Sharded Training
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
PyTorch has its own version of `FSDP <https://pytorch.org/docs/stable/fsdp.html>`_, which is upstreamed from their `fairscale <https://fairscale.readthedocs.io/en/latest/api/nn/fsdp.html>`__ project.
It was introduced in their `v1.11.0 release <https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/>`_. The API is pretty similar to that of FairScale.
It was introduced in their `v1.11.0 release <https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/>`_, but it is recommended to use it with PyTorch v1.12 or later, which is what
Lightning supports. The API is pretty similar to that of FairScale.
.. note::
Currently Fully Sharded Training relies on the user to wrap the model with Fully Sharded within the ``LightningModule``.
This means you must create a single model that is treated as a ``torch.nn.Module`` within the ``LightningModule``.
This is a limitation of Fully Sharded Training that will be resolved in the future.
To activate parameter sharding, you must wrap your model using the ``wrap`` function. Internally in Lightning, we enable a context manager around the ``configure_sharded_model`` function to make sure the ``wrap`` parameters are passed correctly.
Auto Wrapping
"""""""""""""
Model layers should be wrapped with FSDP in a nested way to reduce peak memory and to enable the overlapping of communication and computation. The
simplest way to do this is auto wrapping, which can serve as a drop-in replacement for DDP without changing the rest of your code. You don't
have to ``wrap`` layers manually as you would with manual wrapping.
.. code-block:: python
model = BoringModel()
trainer = Trainer(accelerator="gpu", devices=4, strategy="fsdp_native", precision=16)
trainer.fit(model)
Read more `here <https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/#auto-wrapping>`__.
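If you want control over which submodules get wrapped, you can also pass your own ``auto_wrap_policy`` to the strategy object instead of using the ``"fsdp_native"`` string. The sketch below is illustrative only: ``my_auto_wrap_policy`` is a hypothetical size-based policy that follows the callable signature PyTorch 1.12 expects (``module``, ``recurse``, ``unwrapped_params``), and ``BoringModel`` is the same placeholder model as in the example above.

.. code-block:: python

    from pytorch_lightning import Trainer
    from pytorch_lightning.strategies import DDPFullyShardedNativeStrategy


    def my_auto_wrap_policy(module, recurse, unwrapped_params: int, min_num_params: int = int(1e8)) -> bool:
        # size-based policy: shard a submodule once it holds at least
        # `min_num_params` parameters that are not yet wrapped by FSDP
        return unwrapped_params >= min_num_params


    model = BoringModel()
    strategy = DDPFullyShardedNativeStrategy(auto_wrap_policy=my_auto_wrap_policy)
    trainer = Trainer(accelerator="gpu", devices=4, strategy=strategy, precision=16)
    trainer.fit(model)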
Manual Wrapping
"""""""""""""""
Manual wrapping can be useful for exploring complex sharding strategies by applying ``wrap`` selectively to some parts of the model. To activate
parameter sharding with manual wrapping, you can wrap your model using the ``wrap`` function. Internally in Lightning, we enable a context manager around the ``configure_sharded_model`` function to make sure the ``wrap`` parameters are passed correctly.
When not using Fully Sharded Training, these wrap functions are a no-op. This means that once the changes have been made, there is no need to remove them for other strategies.
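For reference, a minimal manual-wrapping sketch could look like the following. ``ManuallyWrappedModel`` and its single ``linear_layer`` are hypothetical; the essential part is that ``wrap`` is applied inside ``configure_sharded_model``, which Lightning runs under the FSDP wrapping context described above.

.. code-block:: python

    import torch
    from torch.distributed.fsdp.wrap import wrap
    from pytorch_lightning import LightningModule


    class ManuallyWrappedModel(LightningModule):
        def __init__(self):
            super().__init__()
            self.linear_layer = torch.nn.Linear(32, 32)

        def configure_sharded_model(self):
            # shards this layer with the settings configured on the strategy;
            # with any other strategy this call is a no-op
            self.linear_layer = wrap(self.linear_layer)

        def configure_optimizers(self):
            # fetch the parameters only after the layers have been wrapped
            return torch.optim.SGD(self.trainer.model.parameters(), lr=0.1)

Such a model is then trained the same way as in the auto-wrapping example, e.g. with ``Trainer(accelerator="gpu", devices=4, strategy="fsdp_native", precision=16)``.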


@@ -12,6 +12,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added prefix to log message in `seed_everything` with rank info ([#13290](https://github.com/Lightning-AI/lightning/issues/13290))
- Added support for auto wrapping for `DDPFullyShardedNativeStrategy` ([#14252](https://github.com/Lightning-AI/lightning/issues/14252))
- Added support for passing extra init-parameters to the `LightningDataModule.from_datasets` ([#14185](https://github.com/Lightning-AI/lightning/issues/14185))


@@ -25,14 +25,13 @@ if _TORCH_GREATER_EQUAL_1_12:
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
else:
MixedPrecision = None # type: ignore[misc,assignment]
ShardedGradScaler = None # type: ignore[misc,assignment]
class FullyShardedNativeNativeMixedPrecisionPlugin(NativeMixedPrecisionPlugin):
"""Native AMP for Fully Sharded Native Training."""
def __init__(
self, precision: Union[str, int], device: str, scaler: Optional[torch.cuda.amp.GradScaler] = None
) -> None:
def __init__(self, precision: Union[str, int], device: str, scaler: Optional[ShardedGradScaler] = None) -> None:
if not _TORCH_GREATER_EQUAL_1_12:
raise MisconfigurationException(
"`FullyShardedNativeNativeMixedPrecisionPlugin` is supported from PyTorch v1.12.0 onwards."


@@ -13,8 +13,6 @@
# limitations under the License.
from typing import Optional, Union
import torch
from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _FAIRSCALE_AVAILABLE
@@ -29,9 +27,7 @@ else:
class ShardedNativeMixedPrecisionPlugin(NativeMixedPrecisionPlugin):
"""Native AMP for Sharded Training."""
def __init__(
self, precision: Union[str, int], device: str, scaler: Optional[torch.cuda.amp.GradScaler] = None
) -> None:
def __init__(self, precision: Union[str, int], device: str, scaler: Optional[ShardedGradScaler] = None) -> None:
if not _FAIRSCALE_AVAILABLE:
raise MisconfigurationException(
"You have asked for sharded AMP but you have not installed it."


@@ -20,6 +20,7 @@ from torch import Tensor
from torch.distributed.distributed_c10d import _get_default_group, ProcessGroup
import pytorch_lightning as pl
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
@@ -38,9 +39,11 @@ from pytorch_lightning.utilities.distributed import group as _group
from pytorch_lightning.utilities.distributed import init_dist_connection, ReduceOp, sync_ddp_if_available
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.optimizer import optimizers_to_device
from pytorch_lightning.utilities.rank_zero import rank_zero_info
from pytorch_lightning.utilities.seed import reset_seed
from pytorch_lightning.utilities.types import STEP_OUTPUT
if _TORCH_GREATER_EQUAL_1_12:
from torch.distributed.fsdp.fully_sharded_data_parallel import (
@@ -51,6 +54,7 @@ if _TORCH_GREATER_EQUAL_1_12:
)
from torch.distributed.fsdp.wrap import enable_wrap
else:
FullyShardedDataParallel = None # type: ignore[misc,assignment]
MixedPrecision = None # type: ignore[misc,assignment]
BackwardPrefetch = None # type: ignore[misc,assignment]
CPUOffload = None # type: ignore[misc,assignment]
@@ -201,6 +205,28 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy):
self._launcher = _SubprocessScriptLauncher(self.cluster_environment, self.num_processes, self.num_nodes)
self._rank_0_will_call_children_scripts = True
def _setup_model(self, model: torch.nn.Module) -> FullyShardedDataParallel:
"""Wraps the model into a
:class:`~torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel` module."""
# If model is already wrapped, we need to avoid sending the `auto_wrap_policy`
assert self.lightning_module is not None
if (
any(isinstance(mod, FullyShardedDataParallel) for mod in self.lightning_module.modules())
and "auto_wrap_policy" in self.kwargs
):
del self.kwargs["auto_wrap_policy"]
log.detail(f"setting up FSDP model with device id: {self.root_device.index}, kwargs: {self.kwargs}")
return FullyShardedDataParallel(
module=model,
process_group=self.process_group,
cpu_offload=self.cpu_offload,
backward_prefetch=self.backward_prefetch,
mixed_precision=self.mixed_precision_config,
device_id=self.root_device.index,
**self.kwargs,
)
def setup(self, trainer: "pl.Trainer") -> None:
assert self.accelerator is not None
self.accelerator.setup(trainer)
@@ -215,9 +241,20 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy):
assert self.lightning_module is not None
self.lightning_module._device = self.root_device
assert isinstance(self.model, pl.LightningModule)
self.model = _LightningModuleWrapperBase(self.model)
if is_overridden("configure_sharded_model", self.lightning_module):
rank_zero_info(
"You have overridden `LightningModule.configure_sharded_model` hook. It will assume that all the layers"
" are already wrapped for sharding and won't wrap the entire model using `FullyShardedDataParallel`."
)
else:
self.model = self._setup_model(self.model)
self.barrier()
self.setup_optimizers(trainer)
optimizers_to_device(self.optimizers, self.root_device)
self.setup_precision_plugin()
def model_to_device(self) -> None:
@@ -273,6 +310,24 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy):
tensor = sync_ddp_if_available(tensor, group, reduce_op=reduce_op)
return tensor
def training_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
# we don't need precision context since casting is done by FSDP
# read `mixed_precision` docstring here: https://pytorch.org/docs/stable/fsdp.html
assert self.model is not None
return self.model(*args, **kwargs)
def validation_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
assert self.model is not None
return self.model(*args, **kwargs)
def test_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
assert self.model is not None
return self.model(*args, **kwargs)
def predict_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
assert self.model is not None
return self.model(*args, **kwargs)
def _determine_device_ids(self) -> List[int]:
return [self.root_device.index]


@@ -11,7 +11,6 @@ from pytorch_lightning.plugins.precision.fsdp_native_native_amp import FullyShar
from pytorch_lightning.strategies import DDPFullyShardedNativeStrategy
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12
from pytorch_lightning.utilities.types import STEP_OUTPUT
from tests_pytorch.helpers.runif import RunIf
if _TORCH_GREATER_EQUAL_1_12:
@@ -19,6 +18,15 @@ if _TORCH_GREATER_EQUAL_1_12:
from torch.distributed.fsdp.wrap import wrap
def custom_auto_wrap_policy(
module,
recurse,
unwrapped_params: int,
min_num_params: int = int(1e8),
) -> bool:
return unwrapped_params >= 2
@RunIf(min_torch="1.12")
def test_invalid_on_cpu(tmpdir):
"""Test to ensure that we raise Misconfiguration for Native FSDP on CPU."""
@@ -78,38 +86,73 @@ class TestFSDPModel(BoringModel):
def configure_optimizers(self):
return torch.optim.SGD(self.layer.parameters(), lr=0.1)
def on_train_batch_end(self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int) -> None:
def on_train_batch_end(self, outputs, batch, batch_idx) -> None:
self._assert_layer_fsdp_instance()
def on_test_batch_end(
self, outputs: Optional[STEP_OUTPUT], batch: Any, batch_idx: int, dataloader_idx: int
) -> None:
def on_test_batch_end(self, outputs, batch, batch_idx, dataloader_idx) -> None:
self._assert_layer_fsdp_instance()
def on_validation_batch_end(
self, outputs: Optional[STEP_OUTPUT], batch: Any, batch_idx: int, dataloader_idx: int
) -> None:
def on_validation_batch_end(self, outputs, batch, batch_idx, dataloader_idx) -> None:
self._assert_layer_fsdp_instance()
def on_predict_batch_end(self, outputs: Optional[Any], batch: Any, batch_idx: int, dataloader_idx: int) -> None:
def on_predict_batch_end(self, outputs, batch, batch_idx, dataloader_idx) -> None:
self._assert_layer_fsdp_instance()
def _assert_layer_fsdp_instance(self) -> None:
assert isinstance(self.layer, FullyShardedDataParallel)
assert isinstance(self.trainer.strategy.precision_plugin, FullyShardedNativeNativeMixedPrecisionPlugin)
assert isinstance(self.layer.module[0], FullyShardedDataParallel)
assert isinstance(self.layer.module[2], FullyShardedDataParallel)
# root should not be resharding
assert self.layer.reshard_after_forward is False
# Assert that the nested layers are set reshard_after_forward to True
assert self.layer.module[0].reshard_after_forward is True
assert self.layer.module[2].reshard_after_forward is True
precision = torch.float16 if self.precision == 16 else torch.bfloat16
assert self.layer.mixed_precision.param_dtype == precision
assert self.layer.mixed_precision.reduce_dtype == precision
assert self.layer.mixed_precision.buffer_dtype == precision
for layer_num in [0, 2]:
assert isinstance(self.layer.module[layer_num], FullyShardedDataParallel)
# Assert that the nested layers are set reshard_after_forward to True
assert self.layer.module[layer_num].reshard_after_forward is True
assert self.layer[layer_num].mixed_precision.param_dtype == precision
assert self.layer[layer_num].mixed_precision.reduce_dtype == precision
assert self.layer[layer_num].mixed_precision.buffer_dtype == precision
class TestFSDPModelAutoWrapped(BoringModel):
def __init__(self):
super().__init__()
self.layer = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2))
def configure_optimizers(self):
return torch.optim.SGD(self.trainer.model.parameters(), lr=0.1)
def on_train_batch_end(self, outputs, batch, batch_idx) -> None:
self._assert_layer_fsdp_instance()
def on_test_batch_end(self, outputs, batch, batch_idx, dataloader_idx) -> None:
self._assert_layer_fsdp_instance()
def on_validation_batch_end(self, outputs, batch, batch_idx, dataloader_idx) -> None:
self._assert_layer_fsdp_instance()
def on_predict_batch_end(self, outputs, batch, batch_idx, dataloader_idx) -> None:
self._assert_layer_fsdp_instance()
def _assert_layer_fsdp_instance(self) -> None:
assert isinstance(self.layer, torch.nn.Sequential)
assert isinstance(self.trainer.strategy.precision_plugin, FullyShardedNativeNativeMixedPrecisionPlugin)
precision = torch.float16 if self.precision == 16 else torch.bfloat16
for layer_num in [0, 2]:
assert isinstance(self.layer[layer_num], FullyShardedDataParallel)
# Assert that the nested layers are set reshard_after_forward to True
assert self.layer[layer_num].reshard_after_forward
assert self.layer[layer_num].mixed_precision.param_dtype == precision
assert self.layer[layer_num].mixed_precision.reduce_dtype == precision
assert self.layer[layer_num].mixed_precision.buffer_dtype == precision
@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="1.12")
def test_fully_sharded_native_strategy_sync_batchnorm(tmpdir):
@@ -140,18 +183,32 @@ def test_fully_sharded_native_strategy_checkpoint(tmpdir, precision):
@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="1.12")
def test_fully_sharded_native_strategy_checkpoint_multi_gpus(tmpdir):
@pytest.mark.parametrize(
"model, strategy",
[
(TestFSDPModel(), "fsdp_native"),
(TestFSDPModelAutoWrapped(), DDPFullyShardedNativeStrategy),
],
)
def test_fully_sharded_native_strategy_checkpoint_multi_gpus(tmpdir, model, strategy):
"""Test to ensure that checkpoint is saved correctly when using multiple GPUs, and all stages can be run."""
model = TestFSDPModel()
ck = ModelCheckpoint(save_last=True)
if not isinstance(strategy, str):
strategy = strategy(auto_wrap_policy=custom_auto_wrap_policy)
trainer = Trainer(
default_root_dir=tmpdir,
accelerator="gpu",
devices=2,
strategy="fsdp_native",
strategy=strategy,
precision=16,
max_epochs=1,
limit_train_batches=2,
limit_val_batches=2,
limit_test_batches=2,
limit_predict_batches=2,
callbacks=[ck],
)
_run_multiple_stages(trainer, model)
@@ -164,14 +221,20 @@ def _run_multiple_stages(trainer, model, model_path: Optional[str] = None):
trainer.save_checkpoint(model_path, weights_only=True)
_assert_save_equality(trainer, model_path, cls=TestFSDPModel)
_assert_save_equality(trainer, model_path, cls=model.__class__)
# Test entry point
trainer.test(model) # model is wrapped, will not call `configure_sharded_model`
# provide model path, will create a new unwrapped model and load and then call configure_sharded_model to wrap
# provide model path, will create a new unwrapped model and load and then call `configure_sharded_model` to wrap
trainer.test(ckpt_path=model_path)
# Predict entry point
trainer.predict(model) # model is wrapped, will not call `configure_sharded_model`
# provide model path, will create a new unwrapped model and load and then call `configure_sharded_model` to wrap
trainer.predict(ckpt_path=model_path)
def _assert_save_equality(trainer, ckpt_path, cls=TestFSDPModel):
# Use FullySharded to get the state dict for the sake of comparison