diff --git a/docs/source-pytorch/advanced/model_parallel.rst b/docs/source-pytorch/advanced/model_parallel.rst
index d1db58b3f1..50ae2cd282 100644
--- a/docs/source-pytorch/advanced/model_parallel.rst
+++ b/docs/source-pytorch/advanced/model_parallel.rst
@@ -212,14 +212,31 @@ PyTorch Fully Sharded Training
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 PyTorch has it's own version of `FSDP `_ which is upstreamed from their `fairscale `__ project.
-It was introduced in their `v1.11.0 release `_. The API is pretty similar to that of FairScale.
+It was introduced in their `v1.11.0 release `_, but it is recommended to use it with PyTorch v1.12 or later, which is what
+Lightning supports. The API is similar to that of FairScale.
 
-.. note::
-    Currently Fully Sharded Training relies on the user to wrap the model with Fully Sharded within the ``LightningModule``.
-    This means you must create a single model that is treated as a ``torch.nn.Module`` within the ``LightningModule``.
-    This is a limitation of Fully Sharded Training that will be resolved in the future.
 
-To activate parameter sharding, you must wrap your model using the``wrap`` function. Internally in Lightning, we enable a context manager around the ``configure_sharded_model`` function to make sure the ``wrap`` parameters are passed correctly.
+Auto Wrapping
+"""""""""""""
+Model layers should be wrapped with FSDP in a nested way to reduce peak memory usage and to overlap communication with computation. The
+simplest way to do this is auto wrapping, which serves as a drop-in replacement for DDP without changing the rest of your code: you don't
+have to ``wrap`` layers manually as you would with manual wrapping.
+
+.. code-block:: python
+
+    model = BoringModel()
+    trainer = Trainer(accelerator="gpu", devices=4, strategy="fsdp_native", precision=16)
+    trainer.fit(model)
+
+
+Read more `here `__.
+
+
+Manual Wrapping
+"""""""""""""""
+
+Manual wrapping can be useful to explore complex sharding strategies by applying ``wrap`` selectively to some parts of the model. To activate
+parameter sharding with manual wrapping, you can wrap your model using the ``wrap`` function. Internally, Lightning enables a context manager around the ``configure_sharded_model`` hook to make sure the ``wrap`` parameters are passed correctly.
 
 When not using Fully Sharded these wrap functions are a no-op. This means once the changes have been made, there is no need to remove the changes for other strategies.
diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md
index 8dcd45f58b..6ca49228cb 100644
--- a/src/pytorch_lightning/CHANGELOG.md
+++ b/src/pytorch_lightning/CHANGELOG.md
@@ -12,6 +12,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added prefix to log message in `seed_everything` with rank info ([#13290](https://github.com/Lightning-AI/lightning/issues/13290))
 
+- Added support for auto wrapping for `DDPFullyShardedNativeStrategy` ([#14252](https://github.com/Lightning-AI/lightning/issues/14252))
+
+
 - Added support for passing extra init-parameters to the `LightningDataModule.from_datasets` ([#14185](https://github.com/Lightning-AI/lightning/issues/14185))
diff --git a/src/pytorch_lightning/plugins/precision/fsdp_native_native_amp.py b/src/pytorch_lightning/plugins/precision/fsdp_native_native_amp.py
index 2201db9458..f91144124a 100644
--- a/src/pytorch_lightning/plugins/precision/fsdp_native_native_amp.py
+++ b/src/pytorch_lightning/plugins/precision/fsdp_native_native_amp.py
@@ -25,14 +25,13 @@ if _TORCH_GREATER_EQUAL_1_12:
     from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
 else:
     MixedPrecision = None  # type: ignore[misc,assignment]
+    ShardedGradScaler = None  # type: ignore[misc,assignment]
 
 
 class FullyShardedNativeNativeMixedPrecisionPlugin(NativeMixedPrecisionPlugin):
     """Native AMP for Fully Sharded Native Training."""
 
-    def __init__(
-        self, precision: Union[str, int], device: str, scaler: Optional[torch.cuda.amp.GradScaler] = None
-    ) -> None:
+    def __init__(self, precision: Union[str, int], device: str, scaler: Optional[ShardedGradScaler] = None) -> None:
         if not _TORCH_GREATER_EQUAL_1_12:
             raise MisconfigurationException(
                 "`FullyShardedNativeNativeMixedPrecisionPlugin` is supported from PyTorch v1.12.0 onwards."
diff --git a/src/pytorch_lightning/plugins/precision/sharded_native_amp.py b/src/pytorch_lightning/plugins/precision/sharded_native_amp.py
index 570e25bd85..15c23e18ed 100644
--- a/src/pytorch_lightning/plugins/precision/sharded_native_amp.py
+++ b/src/pytorch_lightning/plugins/precision/sharded_native_amp.py
@@ -13,8 +13,6 @@
 # limitations under the License.
 from typing import Optional, Union
 
-import torch
-
 from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _FAIRSCALE_AVAILABLE
@@ -29,9 +27,7 @@ else:
 class ShardedNativeMixedPrecisionPlugin(NativeMixedPrecisionPlugin):
     """Native AMP for Sharded Training."""
 
-    def __init__(
-        self, precision: Union[str, int], device: str, scaler: Optional[torch.cuda.amp.GradScaler] = None
-    ) -> None:
+    def __init__(self, precision: Union[str, int], device: str, scaler: Optional[ShardedGradScaler] = None) -> None:
         if not _FAIRSCALE_AVAILABLE:
             raise MisconfigurationException(
                 "You have asked for sharded AMP but you have not installed it."
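The strategy changes below implement the auto wrapping announced in the CHANGELOG: when ``configure_sharded_model`` is not overridden, the wrapped ``LightningModule`` is handed to ``FullyShardedDataParallel`` as a whole, and any extra keyword arguments given to the strategy (such as an ``auto_wrap_policy``) are forwarded to it. A minimal usage sketch, assuming the ``BoringModel`` demo module and reusing the ``custom_auto_wrap_policy`` helper added in the tests at the end of this patch (the threshold is illustrative, not part of the patch):

.. code-block:: python

    from pytorch_lightning import Trainer
    from pytorch_lightning.demos.boring_classes import BoringModel
    from pytorch_lightning.strategies import DDPFullyShardedNativeStrategy


    def custom_auto_wrap_policy(module, recurse, unwrapped_params: int, min_num_params: int = int(1e8)) -> bool:
        # illustrative policy: shard any submodule that still holds at least 2 unwrapped parameters
        return unwrapped_params >= 2


    # extra keyword arguments on the strategy are forwarded to `FullyShardedDataParallel`
    strategy = DDPFullyShardedNativeStrategy(auto_wrap_policy=custom_auto_wrap_policy)
    trainer = Trainer(accelerator="gpu", devices=2, strategy=strategy, precision=16)
    trainer.fit(BoringModel())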
diff --git a/src/pytorch_lightning/strategies/fully_sharded_native.py b/src/pytorch_lightning/strategies/fully_sharded_native.py
index 4dbf36e4c2..456ddd36f9 100644
--- a/src/pytorch_lightning/strategies/fully_sharded_native.py
+++ b/src/pytorch_lightning/strategies/fully_sharded_native.py
@@ -20,6 +20,7 @@
 from torch import Tensor
 from torch.distributed.distributed_c10d import _get_default_group, ProcessGroup
 import pytorch_lightning as pl
+from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
 from pytorch_lightning.plugins.precision import PrecisionPlugin
@@ -38,9 +39,11 @@ from pytorch_lightning.utilities.distributed import group as _group
 from pytorch_lightning.utilities.distributed import init_dist_connection, ReduceOp, sync_ddp_if_available
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12
+from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.optimizer import optimizers_to_device
 from pytorch_lightning.utilities.rank_zero import rank_zero_info
 from pytorch_lightning.utilities.seed import reset_seed
+from pytorch_lightning.utilities.types import STEP_OUTPUT
 
 if _TORCH_GREATER_EQUAL_1_12:
     from torch.distributed.fsdp.fully_sharded_data_parallel import (
@@ -51,6 +54,7 @@ if _TORCH_GREATER_EQUAL_1_12:
     )
     from torch.distributed.fsdp.wrap import enable_wrap
 else:
+    FullyShardedDataParallel = None  # type: ignore[misc,assignment]
     MixedPrecision = None  # type: ignore[misc,assignment]
     BackwardPrefetch = None  # type: ignore[misc,assignment]
     CPUOffload = None  # type: ignore[misc,assignment]
@@ -201,6 +205,28 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy):
         self._launcher = _SubprocessScriptLauncher(self.cluster_environment, self.num_processes, self.num_nodes)
         self._rank_0_will_call_children_scripts = True
 
+    def _setup_model(self, model: torch.nn.Module) -> FullyShardedDataParallel:
+        """Wraps the model into a
+        :class:`~torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel` module."""
+        # If model is already wrapped, we need to avoid sending the `auto_wrap_policy`
+        assert self.lightning_module is not None
+        if (
+            any(isinstance(mod, FullyShardedDataParallel) for mod in self.lightning_module.modules())
+            and "auto_wrap_policy" in self.kwargs
+        ):
+            del self.kwargs["auto_wrap_policy"]
+
+        log.detail(f"setting up FSDP model with device id: {self.root_device.index}, kwargs: {self.kwargs}")
+        return FullyShardedDataParallel(
+            module=model,
+            process_group=self.process_group,
+            cpu_offload=self.cpu_offload,
+            backward_prefetch=self.backward_prefetch,
+            mixed_precision=self.mixed_precision_config,
+            device_id=self.root_device.index,
+            **self.kwargs,
+        )
+
     def setup(self, trainer: "pl.Trainer") -> None:
         assert self.accelerator is not None
         self.accelerator.setup(trainer)
@@ -215,9 +241,20 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy):
         assert self.lightning_module is not None
         self.lightning_module._device = self.root_device
+        assert isinstance(self.model, pl.LightningModule)
+        self.model = _LightningModuleWrapperBase(self.model)
+        if is_overridden("configure_sharded_model", self.lightning_module):
+            rank_zero_info(
+                "You have overridden `LightningModule.configure_sharded_model` hook."
It will assume that all the layers" + " are already wrapped for sharding and won't wrap the entire model using `FullyShardedDataParallel`." + ) + else: + self.model = self._setup_model(self.model) self.barrier() + self.setup_optimizers(trainer) optimizers_to_device(self.optimizers, self.root_device) + self.setup_precision_plugin() def model_to_device(self) -> None: @@ -273,6 +310,24 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy): tensor = sync_ddp_if_available(tensor, group, reduce_op=reduce_op) return tensor + def training_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT: + # we don't need precision context since casting is done by FSDP + # read `mixed_precision` docstring here: https://pytorch.org/docs/stable/fsdp.html + assert self.model is not None + return self.model(*args, **kwargs) + + def validation_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]: + assert self.model is not None + return self.model(*args, **kwargs) + + def test_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]: + assert self.model is not None + return self.model(*args, **kwargs) + + def predict_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT: + assert self.model is not None + return self.model(*args, **kwargs) + def _determine_device_ids(self) -> List[int]: return [self.root_device.index] diff --git a/tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py b/tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py index ede201da1f..be8bced2cb 100644 --- a/tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py +++ b/tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py @@ -11,7 +11,6 @@ from pytorch_lightning.plugins.precision.fsdp_native_native_amp import FullyShar from pytorch_lightning.strategies import DDPFullyShardedNativeStrategy from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12 -from pytorch_lightning.utilities.types import STEP_OUTPUT from tests_pytorch.helpers.runif import RunIf if _TORCH_GREATER_EQUAL_1_12: @@ -19,6 +18,15 @@ if _TORCH_GREATER_EQUAL_1_12: from torch.distributed.fsdp.wrap import wrap +def custom_auto_wrap_policy( + module, + recurse, + unwrapped_params: int, + min_num_params: int = int(1e8), +) -> bool: + return unwrapped_params >= 2 + + @RunIf(min_torch="1.12") def test_invalid_on_cpu(tmpdir): """Test to ensure that we raise Misconfiguration for Native FSDP on CPU.""" @@ -78,38 +86,73 @@ class TestFSDPModel(BoringModel): def configure_optimizers(self): return torch.optim.SGD(self.layer.parameters(), lr=0.1) - def on_train_batch_end(self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int) -> None: + def on_train_batch_end(self, outputs, batch, batch_idx) -> None: self._assert_layer_fsdp_instance() - def on_test_batch_end( - self, outputs: Optional[STEP_OUTPUT], batch: Any, batch_idx: int, dataloader_idx: int - ) -> None: + def on_test_batch_end(self, outputs, batch, batch_idx, dataloader_idx) -> None: self._assert_layer_fsdp_instance() - def on_validation_batch_end( - self, outputs: Optional[STEP_OUTPUT], batch: Any, batch_idx: int, dataloader_idx: int - ) -> None: + def on_validation_batch_end(self, outputs, batch, batch_idx, dataloader_idx) -> None: self._assert_layer_fsdp_instance() - def on_predict_batch_end(self, outputs: Optional[Any], batch: Any, batch_idx: int, dataloader_idx: int) -> None: + def on_predict_batch_end(self, outputs, batch, batch_idx, dataloader_idx) -> None: 
         self._assert_layer_fsdp_instance()
 
     def _assert_layer_fsdp_instance(self) -> None:
         assert isinstance(self.layer, FullyShardedDataParallel)
         assert isinstance(self.trainer.strategy.precision_plugin, FullyShardedNativeNativeMixedPrecisionPlugin)
-        assert isinstance(self.layer.module[0], FullyShardedDataParallel)
-        assert isinstance(self.layer.module[2], FullyShardedDataParallel)
 
         # root should not be resharding
         assert self.layer.reshard_after_forward is False
-        # Assert that the nested layers are set reshard_after_forward to True
-        assert self.layer.module[0].reshard_after_forward is True
-        assert self.layer.module[2].reshard_after_forward is True
 
         precision = torch.float16 if self.precision == 16 else torch.bfloat16
         assert self.layer.mixed_precision.param_dtype == precision
         assert self.layer.mixed_precision.reduce_dtype == precision
         assert self.layer.mixed_precision.buffer_dtype == precision
 
+        for layer_num in [0, 2]:
+            assert isinstance(self.layer.module[layer_num], FullyShardedDataParallel)
+            # Assert that the nested layers have `reshard_after_forward` set to True
+            assert self.layer.module[layer_num].reshard_after_forward is True
+
+            assert self.layer[layer_num].mixed_precision.param_dtype == precision
+            assert self.layer[layer_num].mixed_precision.reduce_dtype == precision
+            assert self.layer[layer_num].mixed_precision.buffer_dtype == precision
+
+
+class TestFSDPModelAutoWrapped(BoringModel):
+    def __init__(self):
+        super().__init__()
+        self.layer = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2))
+
+    def configure_optimizers(self):
+        return torch.optim.SGD(self.trainer.model.parameters(), lr=0.1)
+
+    def on_train_batch_end(self, outputs, batch, batch_idx) -> None:
+        self._assert_layer_fsdp_instance()
+
+    def on_test_batch_end(self, outputs, batch, batch_idx, dataloader_idx) -> None:
+        self._assert_layer_fsdp_instance()
+
+    def on_validation_batch_end(self, outputs, batch, batch_idx, dataloader_idx) -> None:
+        self._assert_layer_fsdp_instance()
+
+    def on_predict_batch_end(self, outputs, batch, batch_idx, dataloader_idx) -> None:
+        self._assert_layer_fsdp_instance()
+
+    def _assert_layer_fsdp_instance(self) -> None:
+        assert isinstance(self.layer, torch.nn.Sequential)
+        assert isinstance(self.trainer.strategy.precision_plugin, FullyShardedNativeNativeMixedPrecisionPlugin)
+
+        precision = torch.float16 if self.precision == 16 else torch.bfloat16
+        for layer_num in [0, 2]:
+            assert isinstance(self.layer[layer_num], FullyShardedDataParallel)
+            # Assert that the nested layers have `reshard_after_forward` set to True
+            assert self.layer[layer_num].reshard_after_forward
+
+            assert self.layer[layer_num].mixed_precision.param_dtype == precision
+            assert self.layer[layer_num].mixed_precision.reduce_dtype == precision
+            assert self.layer[layer_num].mixed_precision.buffer_dtype == precision
+
 
 @RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="1.12")
 def test_fully_sharded_native_strategy_sync_batchnorm(tmpdir):
@@ -140,18 +183,32 @@ def test_fully_sharded_native_strategy_checkpoint(tmpdir, precision):
 
 
 @RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="1.12")
-def test_fully_sharded_native_strategy_checkpoint_multi_gpus(tmpdir):
+@pytest.mark.parametrize(
+    "model, strategy",
+    [
+        (TestFSDPModel(), "fsdp_native"),
+        (TestFSDPModelAutoWrapped(), DDPFullyShardedNativeStrategy),
+    ],
+)
+def test_fully_sharded_native_strategy_checkpoint_multi_gpus(tmpdir, model, strategy):
     """Test to ensure that checkpoint is saved correctly when
     using multiple GPUs, and all stages can be run."""
-    model = TestFSDPModel()
     ck = ModelCheckpoint(save_last=True)
+
+    if not isinstance(strategy, str):
+        strategy = strategy(auto_wrap_policy=custom_auto_wrap_policy)
+
     trainer = Trainer(
         default_root_dir=tmpdir,
         accelerator="gpu",
         devices=2,
-        strategy="fsdp_native",
+        strategy=strategy,
         precision=16,
         max_epochs=1,
+        limit_train_batches=2,
+        limit_val_batches=2,
+        limit_test_batches=2,
+        limit_predict_batches=2,
         callbacks=[ck],
     )
     _run_multiple_stages(trainer, model)
@@ -164,14 +221,20 @@ def _run_multiple_stages(trainer, model, model_path: Optional[str] = None):
     trainer.save_checkpoint(model_path, weights_only=True)
 
-    _assert_save_equality(trainer, model_path, cls=TestFSDPModel)
+    _assert_save_equality(trainer, model_path, cls=model.__class__)
 
     # Test entry point
     trainer.test(model)  # model is wrapped, will not call `configure_sharded_model`
 
-    # provide model path, will create a new unwrapped model and load and then call configure_shared_model to wrap
+    # provide model path, will create a new unwrapped model and load and then call `configure_sharded_model` to wrap
    trainer.test(ckpt_path=model_path)
 
+    # Predict entry point
+    trainer.predict(model)  # model is wrapped, will not call `configure_sharded_model`
+
+    # provide model path, will create a new unwrapped model and load and then call `configure_sharded_model` to wrap
+    trainer.predict(ckpt_path=model_path)
+
 
 def _assert_save_equality(trainer, ckpt_path, cls=TestFSDPModel):
     # Use FullySharded to get the state dict for the sake of comparison
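As a counterpart to the auto-wrapping example in the docs change above, here is a minimal sketch of the manual-wrapping path, modelled on the ``TestFSDPModel`` test case: layers are wrapped selectively inside ``configure_sharded_model``, around which Lightning opens the FSDP ``enable_wrap`` context. The module name and layer sizes are illustrative, and ``wrap`` stays a no-op for strategies that are not fully sharded.

.. code-block:: python

    import torch
    from torch.distributed.fsdp.wrap import wrap

    from pytorch_lightning import Trainer
    from pytorch_lightning.demos.boring_classes import BoringModel


    class ManuallyWrappedModel(BoringModel):
        def configure_sharded_model(self):
            # only the linear layers are sharded; `wrap` is a no-op outside the FSDP context
            self.layer = torch.nn.Sequential(
                wrap(torch.nn.Linear(32, 32)),
                torch.nn.ReLU(),
                wrap(torch.nn.Linear(32, 2)),
            )

        def configure_optimizers(self):
            return torch.optim.SGD(self.layer.parameters(), lr=0.1)


    trainer = Trainer(accelerator="gpu", devices=2, strategy="fsdp_native", precision=16)
    trainer.fit(ManuallyWrappedModel())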