From 05dbf48ad0b6f9eed79cee8ce8a61839c219f330 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 6 Dec 2022 16:45:33 +0100 Subject: [PATCH] Activation checkpointing in FSDP without boilerplate (#15826) * initial * input type * checkpointing * fsdp in pl * all_close Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../advanced/model_parallel.rst | 25 +++++++++- src/lightning_lite/strategies/fsdp.py | 45 +++++++++++++++--- src/pytorch_lightning/CHANGELOG.md | 4 ++ .../strategies/fully_sharded_native.py | 31 +++++++++--- tests/tests_lite/strategies/test_fsdp.py | 43 ++++++++++++++++- .../strategies/test_fsdp_integration.py | 7 ++- .../test_ddp_fully_sharded_native.py | 47 +++++++++++++++++++ 7 files changed, 185 insertions(+), 17 deletions(-) diff --git a/docs/source-pytorch/advanced/model_parallel.rst b/docs/source-pytorch/advanced/model_parallel.rst index ab96b5339b..d2c86db5ba 100644 --- a/docs/source-pytorch/advanced/model_parallel.rst +++ b/docs/source-pytorch/advanced/model_parallel.rst @@ -428,13 +428,36 @@ You can customize the strategy configuration by adjusting the arguments of :clas native_fsdp = DDPFullyShardedNativeStrategy(cpu_offload=CPUOffload(offload_params=True)) - trainer = pl.Trainer(strategy=native_fsdp, accelerator="gpu", device=4) + trainer = pl.Trainer(strategy=native_fsdp, accelerator="gpu", devices=4) Check out `this tutorial `__ to learn more about the native support. ---- + +Activation Checkpointing +======================== + +Activation checkpointing reduces GPU memory usage by avoiding the storage of intermediate activation tensors in +selected layers. The tradeoff is that computation cost for the backpropagation increases, as the dropped activations +need to be recomputed. + +Enable checkpointing on large layers (like Transformers) by providing the layer class/type to the strategy: + +.. code-block:: python + + from pytorch_lightning.strategies import DDPFullyShardedNativeStrategy + + fsdp = DDPFullyShardedNativeStrategy( + activation_checkpointing=MyTransformerBlock, # or pass a list with multiple types + ) + trainer = pl.Trainer(strategy=fsdp, accelerator="gpu", devices=4) + + +---- + + .. _deepspeed_advanced: ********* diff --git a/src/lightning_lite/strategies/fsdp.py b/src/lightning_lite/strategies/fsdp.py index 46a36bf95b..9e19a5b77c 100644 --- a/src/lightning_lite/strategies/fsdp.py +++ b/src/lightning_lite/strategies/fsdp.py @@ -11,9 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import functools from contextlib import contextmanager from datetime import timedelta -from typing import Any, Dict, Generator, List, Optional, Tuple, TYPE_CHECKING, Union +from typing import Any, Dict, Generator, List, Optional, Tuple, Type, TYPE_CHECKING, Union import torch from torch import Tensor @@ -35,7 +36,7 @@ from lightning_lite.utilities.distributed import ( ) from lightning_lite.utilities.distributed import group as _group from lightning_lite.utilities.distributed import ReduceOp -from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_12 +from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_12, _TORCH_GREATER_EQUAL_1_13 from lightning_lite.utilities.rank_zero import rank_zero_only from lightning_lite.utilities.seed import reset_seed @@ -78,6 +79,10 @@ class FSDPStrategy(ParallelStrategy, _Sharded): computation overlapping. 
The pros and cons of each algorithm is explained in the class ``BackwardPrefetch``. mixed_precision: Mixed Precision config. By default, Lightning will enable FP16 if ``precision=16`` or BF16 if ``precision=bf16`` unless a config is passed in. This is only available in PyTorch 1.12 and later. + activation_checkpointing: A single layer or a list of layer classes for which you want to enable activation + checkpointing. This is typically your transformer block (including attention + feed-forward). + Enabling this can free up a significant amount of memory at the cost of speed since activations in + these layers need to be recomputed during backpropagation. \**kwargs: Optional keywoard arguments passed to the FSDP context manager which will configure the FSDP class when wrapping modules. """ @@ -94,6 +99,7 @@ class FSDPStrategy(ParallelStrategy, _Sharded): cpu_offload: Optional["CPUOffload"] = None, backward_prefetch: Optional["BackwardPrefetch"] = None, mixed_precision: Optional["MixedPrecision"] = None, + activation_checkpointing: Optional[Union[Type[Module], List[Type[Module]]]] = None, **kwargs: Any, ) -> None: if not _TORCH_GREATER_EQUAL_1_12: @@ -112,6 +118,13 @@ class FSDPStrategy(ParallelStrategy, _Sharded): self._backward_sync_control = _FSDPBackwardSyncControl() self._ddp_kwargs = kwargs + if activation_checkpointing and not _TORCH_GREATER_EQUAL_1_13: + raise ValueError("Activation checkpointing requires torch >= 1.13.0. HINT: `pip install -U torch`") + activation_checkpointing = activation_checkpointing or [] + self._activation_checkpointing = ( + [activation_checkpointing] if not isinstance(activation_checkpointing, list) else activation_checkpointing + ) + self.cpu_offload = cpu_offload self.backward_prefetch = backward_prefetch self.mixed_precision = mixed_precision @@ -175,13 +188,12 @@ class FSDPStrategy(ParallelStrategy, _Sharded): :class:`~torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel` module.""" from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel - if ( - any(isinstance(mod, FullyShardedDataParallel) for mod in module.modules()) - and "auto_wrap_policy" in self._ddp_kwargs + if "auto_wrap_policy" in self._ddp_kwargs and any( + isinstance(mod, FullyShardedDataParallel) for mod in module.modules() ): # If model is already wrapped, we need to avoid sending the `auto_wrap_policy` del self._ddp_kwargs["auto_wrap_policy"] - return FullyShardedDataParallel( + wrapped_module = FullyShardedDataParallel( module=module, cpu_offload=self.cpu_offload, backward_prefetch=self.backward_prefetch, @@ -190,6 +202,12 @@ class FSDPStrategy(ParallelStrategy, _Sharded): **self._ddp_kwargs, ) + # activation checkpointing needs to be set up after wrapping the model + if _TORCH_GREATER_EQUAL_1_13 and self._activation_checkpointing: + _setup_activation_checkpointing(module=wrapped_module, layers=self._activation_checkpointing) + + return wrapped_module + def setup_optimizer(self, optimizer: Optimizer) -> Optimizer: """Set up an optimizer for a model wrapped with FSDP. 
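With the hunks above, the Lite ``FSDPStrategy`` accepts the new ``activation_checkpointing`` argument and applies it right after wrapping the module. A minimal usage sketch (``MyTransformerBlock`` is a hypothetical placeholder, not part of this patch):

.. code-block:: python

    import torch.nn as nn

    from lightning_lite.strategies.fsdp import FSDPStrategy


    class MyTransformerBlock(nn.Module):  # placeholder; any nn.Module subclass works
        ...


    # A single layer class enables checkpointing for every submodule of that type ...
    strategy = FSDPStrategy(activation_checkpointing=MyTransformerBlock)
    # ... and a list of classes is accepted as well (mirroring the tests further below).
    strategy = FSDPStrategy(activation_checkpointing=[MyTransformerBlock, nn.Linear])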
@@ -291,6 +309,21 @@ class FSDPStrategy(ParallelStrategy, _Sharded): rank_zero_only.rank = self.cluster_environment.global_rank() +def _setup_activation_checkpointing(module: "FullyShardedDataParallel", layers: List[Type[Module]]) -> None: + from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + apply_activation_checkpointing, + checkpoint_wrapper, + CheckpointImpl, + ) + + check_fn = lambda submodule: isinstance(submodule, tuple(layers)) + wrapper = functools.partial( + checkpoint_wrapper, + checkpoint_impl=CheckpointImpl.NO_REENTRANT, + ) + apply_activation_checkpointing(module, checkpoint_wrapper_fn=wrapper, check_fn=check_fn) + + class _FSDPBackwardSyncControl(_BackwardSyncControl): @contextmanager def no_backward_sync(self, module: Module) -> Generator: diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index 90d45428e9..3537cad330 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -29,6 +29,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added a warning when `self.log(..., logger=True)` is called without a configured logger ([#15814](https://github.com/Lightning-AI/lightning/pull/15814)) + +- Added support for activation checkpointing for the `DDPFullyShardedNativeStrategy` strategy ([#15826](https://github.com/Lightning-AI/lightning/pull/15826)) + + ### Changed - Drop PyTorch 1.9 support ([#15347](https://github.com/Lightning-AI/lightning/pull/15347)) diff --git a/src/pytorch_lightning/strategies/fully_sharded_native.py b/src/pytorch_lightning/strategies/fully_sharded_native.py index 69110db455..38ed803235 100644 --- a/src/pytorch_lightning/strategies/fully_sharded_native.py +++ b/src/pytorch_lightning/strategies/fully_sharded_native.py @@ -13,14 +13,15 @@ # limitations under the License. import contextlib import logging -from typing import Any, Dict, Generator, List, Optional, Union +from typing import Any, Dict, Generator, List, Optional, Type, Union import torch from torch import Tensor +from torch.nn import Module import pytorch_lightning as pl from lightning_lite.plugins import CheckpointIO, ClusterEnvironment -from lightning_lite.strategies.fsdp import _optimizer_has_flat_params +from lightning_lite.strategies.fsdp import _optimizer_has_flat_params, _setup_activation_checkpointing from lightning_lite.utilities.distributed import ( _get_default_process_group_backend_for_device, _init_dist_connection, @@ -38,7 +39,7 @@ from pytorch_lightning.strategies.parallel import ParallelStrategy from pytorch_lightning.strategies.strategy import TBroadcast from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12 +from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12, _TORCH_GREATER_EQUAL_1_13 from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only from pytorch_lightning.utilities.types import STEP_OUTPUT @@ -100,6 +101,10 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy): Mixed Precision config. By default, Lightning will enable FP16 if ``precision=16`` or BF16 if ``precision=bf16`` unless a config is passed in. This is only available in PyTorch 1.12 and later. + activation_checkpointing: A single layer or a list of layer classes for which you want to enable activation + checkpointing. 
This is typically your transformer block (including attention + feed-forward). + Enabling this can free up a significant amount of memory at the cost of speed since activations in + these layers need to be recomputed during backpropagation. \**kwargs: Passed to the FSDP context manager which will configure the FSDP class when wrapping modules. """ @@ -118,6 +123,7 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy): cpu_offload: Optional[CPUOffload] = None, backward_prefetch: Optional[BackwardPrefetch] = None, mixed_precision: Optional[MixedPrecision] = None, + activation_checkpointing: Optional[Union[Type[Module], List[Type[Module]]]] = None, **kwargs: Any, ) -> None: if not _TORCH_GREATER_EQUAL_1_12: @@ -139,6 +145,12 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy): self.backward_prefetch = backward_prefetch self.mixed_precision = mixed_precision self._rank_0_will_call_children_scripts: bool = False + if activation_checkpointing and not _TORCH_GREATER_EQUAL_1_13: + raise ValueError("Activation checkpointing requires torch >= 1.13.0. HINT: `pip install -U torch`") + activation_checkpointing = activation_checkpointing or [] + self._activation_checkpointing = ( + [activation_checkpointing] if not isinstance(activation_checkpointing, list) else activation_checkpointing + ) self.kwargs = kwargs @property @@ -209,15 +221,14 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy): :class:`~torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel` module.""" # If model is already wrapped, we need to avoid sending the `auto_wrap_policy` assert self.lightning_module is not None - if ( - any(isinstance(mod, FullyShardedDataParallel) for mod in self.lightning_module.modules()) - and "auto_wrap_policy" in self.kwargs + if "auto_wrap_policy" in self.kwargs and any( + isinstance(mod, FullyShardedDataParallel) for mod in self.lightning_module.modules() ): del self.kwargs["auto_wrap_policy"] log.detail(f"setting up FSDP model with device id: {self.root_device.index}, kwargs: {self.kwargs}") - return FullyShardedDataParallel( + wrapped_module = FullyShardedDataParallel( module=model, process_group=self.process_group, cpu_offload=self.cpu_offload, @@ -227,6 +238,12 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy): **self.kwargs, ) + # activation checkpointing needs to be set up after wrapping the model + if _TORCH_GREATER_EQUAL_1_13 and self._activation_checkpointing: + _setup_activation_checkpointing(module=wrapped_module, layers=self._activation_checkpointing) + + return wrapped_module + def setup(self, trainer: "pl.Trainer") -> None: assert self.accelerator is not None self.accelerator.setup(trainer) diff --git a/tests/tests_lite/strategies/test_fsdp.py b/tests/tests_lite/strategies/test_fsdp.py index 8f609d53c2..62880ea6e5 100644 --- a/tests/tests_lite/strategies/test_fsdp.py +++ b/tests/tests_lite/strategies/test_fsdp.py @@ -13,7 +13,7 @@ # limitations under the License. 
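The test changes below exercise the new argument for both the Lite and the PL strategies. For reference, the Trainer-side usage they back is the one from the docs above (``MyTransformerBlock`` and ``MyAttention`` are placeholders for your own modules):

.. code-block:: python

    import pytorch_lightning as pl
    from pytorch_lightning.strategies import DDPFullyShardedNativeStrategy

    fsdp = DDPFullyShardedNativeStrategy(
        activation_checkpointing=[MyTransformerBlock, MyAttention],  # placeholder classes
    )
    trainer = pl.Trainer(strategy=fsdp, accelerator="gpu", devices=4)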
from unittest import mock -from unittest.mock import MagicMock, Mock +from unittest.mock import ANY, MagicMock, Mock import pytest import torch @@ -77,3 +77,44 @@ def test_fsdp_no_backward_sync(): pass module.no_sync.assert_called_once() + + +@RunIf(min_torch="1.12") +@mock.patch("lightning_lite.strategies.fsdp._TORCH_GREATER_EQUAL_1_13", False) +def test_fsdp_activation_checkpointing_support(): + """Test that we error out if activation checkpointing requires a newer PyTorch version.""" + with pytest.raises(ValueError, match="Activation checkpointing requires torch >= 1.13.0"): + FSDPStrategy(activation_checkpointing=Mock()) + + +@RunIf(min_torch="1.13") +def test_fsdp_activation_checkpointing(): + """Test that the FSDP strategy can apply activation checkpointing to the given layers.""" + + class Block1(nn.Linear): + pass + + class Block2(nn.Linear): + pass + + class Model(nn.Module): + def __init__(self): + super().__init__() + self.layer0 = nn.Sequential(Block1(4, 4), Block1(5, 5)) + self.layer1 = Block2(2, 2) + self.layer2 = nn.Linear(3, 3) + + strategy = FSDPStrategy(activation_checkpointing=Block1) + assert strategy._activation_checkpointing == [Block1] + + strategy = FSDPStrategy(activation_checkpointing=[Block1, Block2]) + assert strategy._activation_checkpointing == [Block1, Block2] + + strategy._parallel_devices = [torch.device("cuda", 0)] + with mock.patch( + "torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel" + ) as fsdp_mock, mock.patch( + "torch.distributed.algorithms._checkpoint.checkpoint_wrapper.apply_activation_checkpointing" + ) as ckpt_mock: + strategy.setup_module(Model()) + ckpt_mock.assert_called_with(fsdp_mock(), checkpoint_wrapper_fn=ANY, check_fn=ANY) diff --git a/tests/tests_lite/strategies/test_fsdp_integration.py b/tests/tests_lite/strategies/test_fsdp_integration.py index 963918d076..0289ac7241 100644 --- a/tests/tests_lite/strategies/test_fsdp_integration.py +++ b/tests/tests_lite/strategies/test_fsdp_integration.py @@ -72,7 +72,7 @@ def _assert_save_equality(lite, model, ckpt_path): # model parameters are identical after loading for current_param, loaded_param in zip(current_state_dict.values(), loaded_model.state_dict().values()): - assert torch.equal(current_param.float().cpu(), loaded_param.cpu()) + assert torch.allclose(current_param.float().cpu(), loaded_param.cpu()) def _custom_auto_wrap_policy(module, recurse, unwrapped_params: int, min_num_params: int = int(1e8)) -> bool: @@ -84,7 +84,10 @@ def _custom_auto_wrap_policy(module, recurse, unwrapped_params: int, min_num_par @pytest.mark.parametrize("manual_wrapping", [True, False]) def test_fsdp_train_save_load(manual_wrapping, precision): """Test FSDP training, saving and loading with different wrapping and precision settings.""" - strategy = FSDPStrategy(auto_wrap_policy=_custom_auto_wrap_policy) + strategy = FSDPStrategy( + auto_wrap_policy=_custom_auto_wrap_policy, + activation_checkpointing=[torch.nn.Linear], + ) lite = LightningLite(accelerator="cuda", strategy=strategy, devices=2, precision=precision) lite.launch() diff --git a/tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py b/tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py index 5bb6b84d9e..a9b47aad1d 100644 --- a/tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py +++ b/tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py @@ -1,8 +1,11 @@ import os from typing import Any, Dict, Optional +from unittest import mock +from unittest.mock import ANY, Mock import pytest 
import torch +import torch.nn as nn from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ModelCheckpoint @@ -259,3 +262,47 @@ def test_invalid_parameters_in_optimizer(tmpdir): model = NoFlatParametersModel() with pytest.raises(ValueError, match="The optimizer does not seem to reference any FSDP parameters"): trainer.fit(model) + + +@RunIf(min_torch="1.12") +@mock.patch("pytorch_lightning.strategies.fully_sharded_native._TORCH_GREATER_EQUAL_1_13", False) +def test_fully_sharded_native_activation_checkpointing_support(): + """Test that we error out if activation checkpointing requires a newer PyTorch version.""" + with pytest.raises(ValueError, match="Activation checkpointing requires torch >= 1.13.0"): + DDPFullyShardedNativeStrategy(activation_checkpointing=Mock()) + + +@RunIf(min_torch="1.13") +def test_fully_sharded_native_activation_checkpointing(): + """Test that the FSDP strategy can apply activation checkpointing to the given layers.""" + + class Block1(nn.Linear): + pass + + class Block2(nn.Linear): + pass + + class Model(BoringModel): + def __init__(self): + super().__init__() + self.layer0 = nn.Sequential(Block1(4, 4), Block1(5, 5)) + self.layer1 = Block2(2, 2) + self.layer2 = nn.Linear(3, 3) + + strategy = DDPFullyShardedNativeStrategy(activation_checkpointing=Block1) + assert strategy._activation_checkpointing == [Block1] + + strategy = DDPFullyShardedNativeStrategy(activation_checkpointing=[Block1, Block2]) + assert strategy._activation_checkpointing == [Block1, Block2] + + model = Model() + strategy._parallel_devices = [torch.device("cuda", 0)] + strategy._lightning_module = model + strategy._process_group = Mock() + with mock.patch( + "pytorch_lightning.strategies.fully_sharded_native.FullyShardedDataParallel" + ) as fsdp_mock, mock.patch( + "torch.distributed.algorithms._checkpoint.checkpoint_wrapper.apply_activation_checkpointing" + ) as ckpt_mock: + strategy._setup_model(model) + ckpt_mock.assert_called_with(fsdp_mock(), checkpoint_wrapper_fn=ANY, check_fn=ANY)
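The tests above mock out ``FullyShardedDataParallel`` and ``apply_activation_checkpointing``; at runtime, one way to confirm that checkpointing was applied is to look for the wrapper modules inserted by ``checkpoint_wrapper``. This is a sketch only, since it relies on the private ``CheckpointWrapper`` class of torch >= 1.13 and on a hypothetical ``fsdp_model`` variable:

.. code-block:: python

    from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import CheckpointWrapper

    # `fsdp_model` stands in for the FSDP-wrapped module returned by the strategy,
    # e.g. by `FSDPStrategy.setup_module(...)` in Lite.
    wrapped = [name for name, mod in fsdp_model.named_modules() if isinstance(mod, CheckpointWrapper)]
    assert wrapped, "expected at least one activation-checkpointed submodule"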