Add auto wrapping for `DDPFullyShardedNativeStrategy` (#14252)
parent 70deac2cd4
commit 6d00f31f0c
@@ -212,14 +212,31 @@ PyTorch Fully Sharded Training
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

PyTorch has its own version of `FSDP <https://pytorch.org/docs/stable/fsdp.html>`_ which is upstreamed from their `fairscale <https://fairscale.readthedocs.io/en/latest/api/nn/fsdp.html>`__ project.
It was introduced in their `v1.11.0 release <https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/>`_. The API is pretty similar to that of FairScale.
It was introduced in their `v1.11.0 release <https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/>`_, but it is recommended to use it with PyTorch v1.12 or later, which is what
Lightning supports. The API is pretty similar to that of FairScale.

.. note::
    Currently Fully Sharded Training relies on the user to wrap the model with Fully Sharded within the ``LightningModule``.
    This means you must create a single model that is treated as a ``torch.nn.Module`` within the ``LightningModule``.
    This is a limitation of Fully Sharded Training that will be resolved in the future.

To activate parameter sharding, you must wrap your model using the ``wrap`` function. Internally in Lightning, we enable a context manager around the ``configure_sharded_model`` function to make sure the ``wrap`` parameters are passed correctly.

Auto Wrapping
"""""""""""""

Model layers should be wrapped in FSDP in a nested way to save peak memory and enable communication and computation overlapping. The
simplest way to do it is auto wrapping, which can serve as a drop-in replacement for DDP without changing the rest of the code. You don't
have to ``wrap`` layers manually as in the case of manual wrapping.

.. code-block:: python

    model = BoringModel()
    trainer = Trainer(accelerator="gpu", devices=4, strategy="fsdp_native", precision=16)
    trainer.fit(model)

Read more `here <https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/#auto-wrapping>`__.
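
If you need more control over which submodules get wrapped, ``DDPFullyShardedNativeStrategy`` also accepts an ``auto_wrap_policy`` keyword argument, which is forwarded to ``FullyShardedDataParallel`` (the tests in this commit pass a custom policy this way). The snippet below is a rough sketch rather than documented API: it assumes PyTorch v1.12's ``size_based_auto_wrap_policy`` helper, and the parameter threshold is arbitrary.

.. code-block:: python

    import functools

    from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

    from pytorch_lightning import Trainer
    from pytorch_lightning.strategies import DDPFullyShardedNativeStrategy

    # Wrap any submodule holding at least 100k parameters (threshold chosen only for illustration).
    policy = functools.partial(size_based_auto_wrap_policy, min_num_params=int(1e5))

    model = BoringModel()
    trainer = Trainer(
        accelerator="gpu",
        devices=4,
        strategy=DDPFullyShardedNativeStrategy(auto_wrap_policy=policy),
        precision=16,
    )
    trainer.fit(model)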

Manual Wrapping
"""""""""""""""

Manual wrapping can be useful to explore complex sharding strategies by applying ``wrap`` selectively to some parts of the model. To activate
parameter sharding with manual wrapping, you can wrap your model using the ``wrap`` function. Internally in Lightning, we enable a context manager around the ``configure_sharded_model`` function to make sure the ``wrap`` parameters are passed correctly.

When not using Fully Sharded, these ``wrap`` calls are a no-op, so once the changes have been made there is no need to remove them for other strategies.
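
As a rough sketch of what manual wrapping can look like (the module below and its layer sizes are made up for illustration), you can call ``wrap`` on selected submodules inside ``configure_sharded_model``:

.. code-block:: python

    import torch
    from torch.distributed.fsdp.wrap import wrap

    from pytorch_lightning import LightningModule


    class ManuallyWrappedModel(LightningModule):
        def __init__(self):
            super().__init__()
            self.model = None

        def configure_sharded_model(self):
            # Lightning opens the FSDP wrapping context around this hook,
            # so these `wrap` calls shard only the selected submodules.
            self.model = torch.nn.Sequential(
                wrap(torch.nn.Linear(32, 32)),
                torch.nn.ReLU(),
                wrap(torch.nn.Linear(32, 2)),
            )

        def forward(self, x):
            return self.model(x)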
@@ -12,6 +12,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added prefix to log message in `seed_everything` with rank info ([#13290](https://github.com/Lightning-AI/lightning/issues/13290))

- Added support for auto wrapping for `DDPFullyShardedNativeStrategy` ([#14252](https://github.com/Lightning-AI/lightning/issues/14252))

- Added support for passing extra init-parameters to the `LightningDataModule.from_datasets` ([#14185](https://github.com/Lightning-AI/lightning/issues/14185))
@@ -25,14 +25,13 @@ if _TORCH_GREATER_EQUAL_1_12:
    from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
else:
    MixedPrecision = None  # type: ignore[misc,assignment]
    ShardedGradScaler = None  # type: ignore[misc,assignment]


class FullyShardedNativeNativeMixedPrecisionPlugin(NativeMixedPrecisionPlugin):
    """Native AMP for Fully Sharded Native Training."""

    def __init__(
        self, precision: Union[str, int], device: str, scaler: Optional[torch.cuda.amp.GradScaler] = None
    ) -> None:
    def __init__(self, precision: Union[str, int], device: str, scaler: Optional[ShardedGradScaler] = None) -> None:
        if not _TORCH_GREATER_EQUAL_1_12:
            raise MisconfigurationException(
                "`FullyShardedNativeNativeMixedPrecisionPlugin` is supported from PyTorch v1.12.0 onwards."
@@ -13,8 +13,6 @@
# limitations under the License.
from typing import Optional, Union

import torch

from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _FAIRSCALE_AVAILABLE
@@ -29,9 +27,7 @@ else:
class ShardedNativeMixedPrecisionPlugin(NativeMixedPrecisionPlugin):
    """Native AMP for Sharded Training."""

    def __init__(
        self, precision: Union[str, int], device: str, scaler: Optional[torch.cuda.amp.GradScaler] = None
    ) -> None:
    def __init__(self, precision: Union[str, int], device: str, scaler: Optional[ShardedGradScaler] = None) -> None:
        if not _FAIRSCALE_AVAILABLE:
            raise MisconfigurationException(
                "You have asked for sharded AMP but you have not installed it."
@@ -20,6 +20,7 @@ from torch import Tensor
from torch.distributed.distributed_c10d import _get_default_group, ProcessGroup

import pytorch_lightning as pl
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
@@ -38,9 +39,11 @@ from pytorch_lightning.utilities.distributed import group as _group
from pytorch_lightning.utilities.distributed import init_dist_connection, ReduceOp, sync_ddp_if_available
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.optimizer import optimizers_to_device
from pytorch_lightning.utilities.rank_zero import rank_zero_info
from pytorch_lightning.utilities.seed import reset_seed
from pytorch_lightning.utilities.types import STEP_OUTPUT

if _TORCH_GREATER_EQUAL_1_12:
    from torch.distributed.fsdp.fully_sharded_data_parallel import (
@@ -51,6 +54,7 @@ if _TORCH_GREATER_EQUAL_1_12:
    )
    from torch.distributed.fsdp.wrap import enable_wrap
else:
    FullyShardedDataParallel = None  # type: ignore[misc,assignment]
    MixedPrecision = None  # type: ignore[misc,assignment]
    BackwardPrefetch = None  # type: ignore[misc,assignment]
    CPUOffload = None  # type: ignore[misc,assignment]
@@ -201,6 +205,28 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy):
            self._launcher = _SubprocessScriptLauncher(self.cluster_environment, self.num_processes, self.num_nodes)
            self._rank_0_will_call_children_scripts = True

    def _setup_model(self, model: torch.nn.Module) -> FullyShardedDataParallel:
        """Wraps the model into a
        :class:`~torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel` module."""
        # If model is already wrapped, we need to avoid sending the `auto_wrap_policy`
        assert self.lightning_module is not None
        if (
            any(isinstance(mod, FullyShardedDataParallel) for mod in self.lightning_module.modules())
            and "auto_wrap_policy" in self.kwargs
        ):
            del self.kwargs["auto_wrap_policy"]

        log.detail(f"setting up FSDP model with device id: {self.root_device.index}, kwargs: {self.kwargs}")
        return FullyShardedDataParallel(
            module=model,
            process_group=self.process_group,
            cpu_offload=self.cpu_offload,
            backward_prefetch=self.backward_prefetch,
            mixed_precision=self.mixed_precision_config,
            device_id=self.root_device.index,
            **self.kwargs,
        )

    def setup(self, trainer: "pl.Trainer") -> None:
        assert self.accelerator is not None
        self.accelerator.setup(trainer)
@@ -215,9 +241,20 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy):
        assert self.lightning_module is not None
        self.lightning_module._device = self.root_device

        assert isinstance(self.model, pl.LightningModule)
        self.model = _LightningModuleWrapperBase(self.model)
        if is_overridden("configure_sharded_model", self.lightning_module):
            rank_zero_info(
                "You have overridden `LightningModule.configure_sharded_model` hook. It will assume that all the layers"
                " are already wrapped for sharding and won't wrap the entire model using `FullyShardedDataParallel`."
            )
        else:
            self.model = self._setup_model(self.model)
        self.barrier()

        self.setup_optimizers(trainer)
        optimizers_to_device(self.optimizers, self.root_device)

        self.setup_precision_plugin()

    def model_to_device(self) -> None:
@@ -273,6 +310,24 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy):
            tensor = sync_ddp_if_available(tensor, group, reduce_op=reduce_op)
        return tensor

    def training_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
        # we don't need precision context since casting is done by FSDP
        # read `mixed_precision` docstring here: https://pytorch.org/docs/stable/fsdp.html
        assert self.model is not None
        return self.model(*args, **kwargs)

    def validation_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
        assert self.model is not None
        return self.model(*args, **kwargs)

    def test_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
        assert self.model is not None
        return self.model(*args, **kwargs)

    def predict_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
        assert self.model is not None
        return self.model(*args, **kwargs)

    def _determine_device_ids(self) -> List[int]:
        return [self.root_device.index]
@@ -11,7 +11,6 @@ from pytorch_lightning.plugins.precision.fsdp_native_native_amp import FullyShar
from pytorch_lightning.strategies import DDPFullyShardedNativeStrategy
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12
from pytorch_lightning.utilities.types import STEP_OUTPUT
from tests_pytorch.helpers.runif import RunIf

if _TORCH_GREATER_EQUAL_1_12:
@@ -19,6 +18,15 @@ if _TORCH_GREATER_EQUAL_1_12:
    from torch.distributed.fsdp.wrap import wrap


def custom_auto_wrap_policy(
    module,
    recurse,
    unwrapped_params: int,
    min_num_params: int = int(1e8),
) -> bool:
    return unwrapped_params >= 2


@RunIf(min_torch="1.12")
def test_invalid_on_cpu(tmpdir):
    """Test to ensure that we raise Misconfiguration for Native FSDP on CPU."""
@@ -78,38 +86,73 @@ class TestFSDPModel(BoringModel):
    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)

    def on_train_batch_end(self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int) -> None:
    def on_train_batch_end(self, outputs, batch, batch_idx) -> None:
        self._assert_layer_fsdp_instance()

    def on_test_batch_end(
        self, outputs: Optional[STEP_OUTPUT], batch: Any, batch_idx: int, dataloader_idx: int
    ) -> None:
    def on_test_batch_end(self, outputs, batch, batch_idx, dataloader_idx) -> None:
        self._assert_layer_fsdp_instance()

    def on_validation_batch_end(
        self, outputs: Optional[STEP_OUTPUT], batch: Any, batch_idx: int, dataloader_idx: int
    ) -> None:
    def on_validation_batch_end(self, outputs, batch, batch_idx, dataloader_idx) -> None:
        self._assert_layer_fsdp_instance()

    def on_predict_batch_end(self, outputs: Optional[Any], batch: Any, batch_idx: int, dataloader_idx: int) -> None:
    def on_predict_batch_end(self, outputs, batch, batch_idx, dataloader_idx) -> None:
        self._assert_layer_fsdp_instance()

    def _assert_layer_fsdp_instance(self) -> None:
        assert isinstance(self.layer, FullyShardedDataParallel)
        assert isinstance(self.trainer.strategy.precision_plugin, FullyShardedNativeNativeMixedPrecisionPlugin)
        assert isinstance(self.layer.module[0], FullyShardedDataParallel)
        assert isinstance(self.layer.module[2], FullyShardedDataParallel)
        # root should not be resharding
        assert self.layer.reshard_after_forward is False
        # Assert that the nested layers are set reshard_after_forward to True
        assert self.layer.module[0].reshard_after_forward is True
        assert self.layer.module[2].reshard_after_forward is True

        precision = torch.float16 if self.precision == 16 else torch.bfloat16
        assert self.layer.mixed_precision.param_dtype == precision
        assert self.layer.mixed_precision.reduce_dtype == precision
        assert self.layer.mixed_precision.buffer_dtype == precision

        for layer_num in [0, 2]:
            assert isinstance(self.layer.module[layer_num], FullyShardedDataParallel)
            # Assert that the nested layers are set reshard_after_forward to True
            assert self.layer.module[layer_num].reshard_after_forward is True

            assert self.layer[layer_num].mixed_precision.param_dtype == precision
            assert self.layer[layer_num].mixed_precision.reduce_dtype == precision
            assert self.layer[layer_num].mixed_precision.buffer_dtype == precision


class TestFSDPModelAutoWrapped(BoringModel):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2))

    def configure_optimizers(self):
        return torch.optim.SGD(self.trainer.model.parameters(), lr=0.1)

    def on_train_batch_end(self, outputs, batch, batch_idx) -> None:
        self._assert_layer_fsdp_instance()

    def on_test_batch_end(self, outputs, batch, batch_idx, dataloader_idx) -> None:
        self._assert_layer_fsdp_instance()

    def on_validation_batch_end(self, outputs, batch, batch_idx, dataloader_idx) -> None:
        self._assert_layer_fsdp_instance()

    def on_predict_batch_end(self, outputs, batch, batch_idx, dataloader_idx) -> None:
        self._assert_layer_fsdp_instance()

    def _assert_layer_fsdp_instance(self) -> None:
        assert isinstance(self.layer, torch.nn.Sequential)
        assert isinstance(self.trainer.strategy.precision_plugin, FullyShardedNativeNativeMixedPrecisionPlugin)

        precision = torch.float16 if self.precision == 16 else torch.bfloat16
        for layer_num in [0, 2]:
            assert isinstance(self.layer[layer_num], FullyShardedDataParallel)
            # Assert that the nested layers are set reshard_after_forward to True
            assert self.layer[layer_num].reshard_after_forward

            assert self.layer[layer_num].mixed_precision.param_dtype == precision
            assert self.layer[layer_num].mixed_precision.reduce_dtype == precision
            assert self.layer[layer_num].mixed_precision.buffer_dtype == precision


@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="1.12")
def test_fully_sharded_native_strategy_sync_batchnorm(tmpdir):
@@ -140,18 +183,32 @@ def test_fully_sharded_native_strategy_checkpoint(tmpdir, precision):

@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="1.12")
def test_fully_sharded_native_strategy_checkpoint_multi_gpus(tmpdir):
@pytest.mark.parametrize(
    "model, strategy",
    [
        (TestFSDPModel(), "fsdp_native"),
        (TestFSDPModelAutoWrapped(), DDPFullyShardedNativeStrategy),
    ],
)
def test_fully_sharded_native_strategy_checkpoint_multi_gpus(tmpdir, model, strategy):
    """Test to ensure that checkpoint is saved correctly when using multiple GPUs, and all stages can be run."""

    model = TestFSDPModel()
    ck = ModelCheckpoint(save_last=True)

    if not isinstance(strategy, str):
        strategy = strategy(auto_wrap_policy=custom_auto_wrap_policy)

    trainer = Trainer(
        default_root_dir=tmpdir,
        accelerator="gpu",
        devices=2,
        strategy="fsdp_native",
        strategy=strategy,
        precision=16,
        max_epochs=1,
        limit_train_batches=2,
        limit_val_batches=2,
        limit_test_batches=2,
        limit_predict_batches=2,
        callbacks=[ck],
    )
    _run_multiple_stages(trainer, model)
@@ -164,14 +221,20 @@ def _run_multiple_stages(trainer, model, model_path: Optional[str] = None):

    trainer.save_checkpoint(model_path, weights_only=True)

    _assert_save_equality(trainer, model_path, cls=TestFSDPModel)
    _assert_save_equality(trainer, model_path, cls=model.__class__)

    # Test entry point
    trainer.test(model)  # model is wrapped, will not call `configure_sharded_model`

    # provide model path, will create a new unwrapped model and load and then call configure_shared_model to wrap
    # provide model path, will create a new unwrapped model and load and then call `configure_shared_model` to wrap
    trainer.test(ckpt_path=model_path)

    # Predict entry point
    trainer.predict(model)  # model is wrapped, will not call `configure_sharded_model`

    # provide model path, will create a new unwrapped model and load and then call `configure_shared_model` to wrap
    trainer.predict(ckpt_path=model_path)


def _assert_save_equality(trainer, ckpt_path, cls=TestFSDPModel):
    # Use FullySharded to get the state dict for the sake of comparison