Activation checkpointing in FSDP without boilerplate (#15826)

* initial
* input type
* checkpointing
* fsdp in pl
* all_close

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Adrian Wälchli 2022-12-06 16:45:33 +01:00 committed by GitHub
parent 2992002beb
commit 05dbf48ad0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 185 additions and 17 deletions

View File

@ -428,13 +428,36 @@ You can customize the strategy configuration by adjusting the arguments of :clas
native_fsdp = DDPFullyShardedNativeStrategy(cpu_offload=CPUOffload(offload_params=True)) native_fsdp = DDPFullyShardedNativeStrategy(cpu_offload=CPUOffload(offload_params=True))
trainer = pl.Trainer(strategy=native_fsdp, accelerator="gpu", device=4) trainer = pl.Trainer(strategy=native_fsdp, accelerator="gpu", devices=4)
Check out `this tutorial <https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html>`__ to learn more about the native support. Check out `this tutorial <https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html>`__ to learn more about the native support.
---- ----
Activation Checkpointing
========================
Activation checkpointing reduces GPU memory usage by avoiding the storage of intermediate activation tensors in
selected layers. The tradeoff is that computation cost for the backpropagation increases, as the dropped activations
need to be recomputed.
Enable checkpointing on large layers (like Transformers) by providing the layer class/type to the strategy:
.. code-block:: python
from pytorch_lightning.strategies import DDPFullyShardedNativeStrategy
fsdp = DDPFullyShardedNativeStrategy(
activation_checkpointing=MyTransformerBlock, # or pass a list with multiple types
)
trainer = pl.Trainer(strategy=fsdp, accelerator="gpu", devices=4)
----
.. _deepspeed_advanced: .. _deepspeed_advanced:
********* *********

View File

@ -11,9 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import functools
from contextlib import contextmanager from contextlib import contextmanager
from datetime import timedelta from datetime import timedelta
from typing import Any, Dict, Generator, List, Optional, Tuple, TYPE_CHECKING, Union from typing import Any, Dict, Generator, List, Optional, Tuple, Type, TYPE_CHECKING, Union
import torch import torch
from torch import Tensor from torch import Tensor
@ -35,7 +36,7 @@ from lightning_lite.utilities.distributed import (
) )
from lightning_lite.utilities.distributed import group as _group from lightning_lite.utilities.distributed import group as _group
from lightning_lite.utilities.distributed import ReduceOp from lightning_lite.utilities.distributed import ReduceOp
from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_12 from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_12, _TORCH_GREATER_EQUAL_1_13
from lightning_lite.utilities.rank_zero import rank_zero_only from lightning_lite.utilities.rank_zero import rank_zero_only
from lightning_lite.utilities.seed import reset_seed from lightning_lite.utilities.seed import reset_seed
@ -78,6 +79,10 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
computation overlapping. The pros and cons of each algorithm is explained in the class ``BackwardPrefetch``. computation overlapping. The pros and cons of each algorithm is explained in the class ``BackwardPrefetch``.
mixed_precision: Mixed Precision config. By default, Lightning will enable FP16 if ``precision=16`` or BF16 mixed_precision: Mixed Precision config. By default, Lightning will enable FP16 if ``precision=16`` or BF16
if ``precision=bf16`` unless a config is passed in. This is only available in PyTorch 1.12 and later. if ``precision=bf16`` unless a config is passed in. This is only available in PyTorch 1.12 and later.
activation_checkpointing: A single layer or a list of layer classes for which you want to enable activation
checkpointing. This is typically your transformer block (including attention + feed-forward).
Enabling this can free up a significant amount of memory at the cost of speed since activations in
these layers need to be recomputed during backpropagation.
\**kwargs: Optional keywoard arguments passed to the FSDP context manager which will configure the FSDP class \**kwargs: Optional keywoard arguments passed to the FSDP context manager which will configure the FSDP class
when wrapping modules. when wrapping modules.
""" """
@ -94,6 +99,7 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
cpu_offload: Optional["CPUOffload"] = None, cpu_offload: Optional["CPUOffload"] = None,
backward_prefetch: Optional["BackwardPrefetch"] = None, backward_prefetch: Optional["BackwardPrefetch"] = None,
mixed_precision: Optional["MixedPrecision"] = None, mixed_precision: Optional["MixedPrecision"] = None,
activation_checkpointing: Optional[Union[Type[Module], List[Type[Module]]]] = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
if not _TORCH_GREATER_EQUAL_1_12: if not _TORCH_GREATER_EQUAL_1_12:
@ -112,6 +118,13 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
self._backward_sync_control = _FSDPBackwardSyncControl() self._backward_sync_control = _FSDPBackwardSyncControl()
self._ddp_kwargs = kwargs self._ddp_kwargs = kwargs
if activation_checkpointing and not _TORCH_GREATER_EQUAL_1_13:
raise ValueError("Activation checkpointing requires torch >= 1.13.0. HINT: `pip install -U torch`")
activation_checkpointing = activation_checkpointing or []
self._activation_checkpointing = (
[activation_checkpointing] if not isinstance(activation_checkpointing, list) else activation_checkpointing
)
self.cpu_offload = cpu_offload self.cpu_offload = cpu_offload
self.backward_prefetch = backward_prefetch self.backward_prefetch = backward_prefetch
self.mixed_precision = mixed_precision self.mixed_precision = mixed_precision
@ -175,13 +188,12 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
:class:`~torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel` module.""" :class:`~torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel` module."""
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel
if ( if "auto_wrap_policy" in self._ddp_kwargs and any(
any(isinstance(mod, FullyShardedDataParallel) for mod in module.modules()) isinstance(mod, FullyShardedDataParallel) for mod in module.modules()
and "auto_wrap_policy" in self._ddp_kwargs
): ):
# If model is already wrapped, we need to avoid sending the `auto_wrap_policy` # If model is already wrapped, we need to avoid sending the `auto_wrap_policy`
del self._ddp_kwargs["auto_wrap_policy"] del self._ddp_kwargs["auto_wrap_policy"]
return FullyShardedDataParallel( wrapped_module = FullyShardedDataParallel(
module=module, module=module,
cpu_offload=self.cpu_offload, cpu_offload=self.cpu_offload,
backward_prefetch=self.backward_prefetch, backward_prefetch=self.backward_prefetch,
@ -190,6 +202,12 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
**self._ddp_kwargs, **self._ddp_kwargs,
) )
# activation checkpointing needs to be set up after wrapping the model
if _TORCH_GREATER_EQUAL_1_13 and self._activation_checkpointing:
_setup_activation_checkpointing(module=wrapped_module, layers=self._activation_checkpointing)
return wrapped_module
def setup_optimizer(self, optimizer: Optimizer) -> Optimizer: def setup_optimizer(self, optimizer: Optimizer) -> Optimizer:
"""Set up an optimizer for a model wrapped with FSDP. """Set up an optimizer for a model wrapped with FSDP.
@ -291,6 +309,21 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
rank_zero_only.rank = self.cluster_environment.global_rank() rank_zero_only.rank = self.cluster_environment.global_rank()
def _setup_activation_checkpointing(module: "FullyShardedDataParallel", layers: List[Type[Module]]) -> None:
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
apply_activation_checkpointing,
checkpoint_wrapper,
CheckpointImpl,
)
check_fn = lambda submodule: isinstance(submodule, tuple(layers))
wrapper = functools.partial(
checkpoint_wrapper,
checkpoint_impl=CheckpointImpl.NO_REENTRANT,
)
apply_activation_checkpointing(module, checkpoint_wrapper_fn=wrapper, check_fn=check_fn)
class _FSDPBackwardSyncControl(_BackwardSyncControl): class _FSDPBackwardSyncControl(_BackwardSyncControl):
@contextmanager @contextmanager
def no_backward_sync(self, module: Module) -> Generator: def no_backward_sync(self, module: Module) -> Generator:

View File

@ -29,6 +29,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added a warning when `self.log(..., logger=True)` is called without a configured logger ([#15814](https://github.com/Lightning-AI/lightning/pull/15814)) - Added a warning when `self.log(..., logger=True)` is called without a configured logger ([#15814](https://github.com/Lightning-AI/lightning/pull/15814))
- Added support for activation checkpointing for the `DDPFullyShardedNativeStrategy` strategy ([#15826](https://github.com/Lightning-AI/lightning/pull/15826))
### Changed ### Changed
- Drop PyTorch 1.9 support ([#15347](https://github.com/Lightning-AI/lightning/pull/15347)) - Drop PyTorch 1.9 support ([#15347](https://github.com/Lightning-AI/lightning/pull/15347))

View File

@ -13,14 +13,15 @@
# limitations under the License. # limitations under the License.
import contextlib import contextlib
import logging import logging
from typing import Any, Dict, Generator, List, Optional, Union from typing import Any, Dict, Generator, List, Optional, Type, Union
import torch import torch
from torch import Tensor from torch import Tensor
from torch.nn import Module
import pytorch_lightning as pl import pytorch_lightning as pl
from lightning_lite.plugins import CheckpointIO, ClusterEnvironment from lightning_lite.plugins import CheckpointIO, ClusterEnvironment
from lightning_lite.strategies.fsdp import _optimizer_has_flat_params from lightning_lite.strategies.fsdp import _optimizer_has_flat_params, _setup_activation_checkpointing
from lightning_lite.utilities.distributed import ( from lightning_lite.utilities.distributed import (
_get_default_process_group_backend_for_device, _get_default_process_group_backend_for_device,
_init_dist_connection, _init_dist_connection,
@ -38,7 +39,7 @@ from pytorch_lightning.strategies.parallel import ParallelStrategy
from pytorch_lightning.strategies.strategy import TBroadcast from pytorch_lightning.strategies.strategy import TBroadcast
from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12 from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12, _TORCH_GREATER_EQUAL_1_13
from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only
from pytorch_lightning.utilities.types import STEP_OUTPUT from pytorch_lightning.utilities.types import STEP_OUTPUT
@ -100,6 +101,10 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy):
Mixed Precision config. By default, Lightning will enable FP16 if ``precision=16`` Mixed Precision config. By default, Lightning will enable FP16 if ``precision=16``
or BF16 if ``precision=bf16`` unless a config is passed in. or BF16 if ``precision=bf16`` unless a config is passed in.
This is only available in PyTorch 1.12 and later. This is only available in PyTorch 1.12 and later.
activation_checkpointing: A single layer or a list of layer classes for which you want to enable activation
checkpointing. This is typically your transformer block (including attention + feed-forward).
Enabling this can free up a significant amount of memory at the cost of speed since activations in
these layers need to be recomputed during backpropagation.
\**kwargs: Passed to the FSDP context manager which will configure the FSDP class when wrapping modules. \**kwargs: Passed to the FSDP context manager which will configure the FSDP class when wrapping modules.
""" """
@ -118,6 +123,7 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy):
cpu_offload: Optional[CPUOffload] = None, cpu_offload: Optional[CPUOffload] = None,
backward_prefetch: Optional[BackwardPrefetch] = None, backward_prefetch: Optional[BackwardPrefetch] = None,
mixed_precision: Optional[MixedPrecision] = None, mixed_precision: Optional[MixedPrecision] = None,
activation_checkpointing: Optional[Union[Type[Module], List[Type[Module]]]] = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
if not _TORCH_GREATER_EQUAL_1_12: if not _TORCH_GREATER_EQUAL_1_12:
@ -139,6 +145,12 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy):
self.backward_prefetch = backward_prefetch self.backward_prefetch = backward_prefetch
self.mixed_precision = mixed_precision self.mixed_precision = mixed_precision
self._rank_0_will_call_children_scripts: bool = False self._rank_0_will_call_children_scripts: bool = False
if activation_checkpointing and not _TORCH_GREATER_EQUAL_1_13:
raise ValueError("Activation checkpointing requires torch >= 1.13.0. HINT: `pip install -U torch`")
activation_checkpointing = activation_checkpointing or []
self._activation_checkpointing = (
[activation_checkpointing] if not isinstance(activation_checkpointing, list) else activation_checkpointing
)
self.kwargs = kwargs self.kwargs = kwargs
@property @property
@ -209,15 +221,14 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy):
:class:`~torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel` module.""" :class:`~torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel` module."""
# If model is already wrapped, we need to avoid sending the `auto_wrap_policy` # If model is already wrapped, we need to avoid sending the `auto_wrap_policy`
assert self.lightning_module is not None assert self.lightning_module is not None
if ( if "auto_wrap_policy" in self.kwargs and any(
any(isinstance(mod, FullyShardedDataParallel) for mod in self.lightning_module.modules()) isinstance(mod, FullyShardedDataParallel) for mod in self.lightning_module.modules()
and "auto_wrap_policy" in self.kwargs
): ):
del self.kwargs["auto_wrap_policy"] del self.kwargs["auto_wrap_policy"]
log.detail(f"setting up FSDP model with device id: {self.root_device.index}, kwargs: {self.kwargs}") log.detail(f"setting up FSDP model with device id: {self.root_device.index}, kwargs: {self.kwargs}")
return FullyShardedDataParallel( wrapped_module = FullyShardedDataParallel(
module=model, module=model,
process_group=self.process_group, process_group=self.process_group,
cpu_offload=self.cpu_offload, cpu_offload=self.cpu_offload,
@ -227,6 +238,12 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy):
**self.kwargs, **self.kwargs,
) )
# activation checkpointing needs to be set up after wrapping the model
if _TORCH_GREATER_EQUAL_1_13 and self._activation_checkpointing:
_setup_activation_checkpointing(module=wrapped_module, layers=self._activation_checkpointing)
return wrapped_module
def setup(self, trainer: "pl.Trainer") -> None: def setup(self, trainer: "pl.Trainer") -> None:
assert self.accelerator is not None assert self.accelerator is not None
self.accelerator.setup(trainer) self.accelerator.setup(trainer)

View File

@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
from unittest import mock from unittest import mock
from unittest.mock import MagicMock, Mock from unittest.mock import ANY, MagicMock, Mock
import pytest import pytest
import torch import torch
@ -77,3 +77,44 @@ def test_fsdp_no_backward_sync():
pass pass
module.no_sync.assert_called_once() module.no_sync.assert_called_once()
@RunIf(min_torch="1.12")
@mock.patch("lightning_lite.strategies.fsdp._TORCH_GREATER_EQUAL_1_13", False)
def test_fsdp_activation_checkpointing_support():
"""Test that we error out if activation checkpointing requires a newer PyTorch version."""
with pytest.raises(ValueError, match="Activation checkpointing requires torch >= 1.13.0"):
FSDPStrategy(activation_checkpointing=Mock())
@RunIf(min_torch="1.13")
def test_fsdp_activation_checkpointing():
"""Test that the FSDP strategy can apply activation checkpointing to the given layers."""
class Block1(nn.Linear):
pass
class Block2(nn.Linear):
pass
class Model(nn.Module):
def __init__(self):
super().__init__()
self.layer0 = nn.Sequential(Block1(4, 4), Block1(5, 5))
self.layer1 = Block2(2, 2)
self.layer2 = nn.Linear(3, 3)
strategy = FSDPStrategy(activation_checkpointing=Block1)
assert strategy._activation_checkpointing == [Block1]
strategy = FSDPStrategy(activation_checkpointing=[Block1, Block2])
assert strategy._activation_checkpointing == [Block1, Block2]
strategy._parallel_devices = [torch.device("cuda", 0)]
with mock.patch(
"torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel"
) as fsdp_mock, mock.patch(
"torch.distributed.algorithms._checkpoint.checkpoint_wrapper.apply_activation_checkpointing"
) as ckpt_mock:
strategy.setup_module(Model())
ckpt_mock.assert_called_with(fsdp_mock(), checkpoint_wrapper_fn=ANY, check_fn=ANY)

View File

@ -72,7 +72,7 @@ def _assert_save_equality(lite, model, ckpt_path):
# model parameters are identical after loading # model parameters are identical after loading
for current_param, loaded_param in zip(current_state_dict.values(), loaded_model.state_dict().values()): for current_param, loaded_param in zip(current_state_dict.values(), loaded_model.state_dict().values()):
assert torch.equal(current_param.float().cpu(), loaded_param.cpu()) assert torch.allclose(current_param.float().cpu(), loaded_param.cpu())
def _custom_auto_wrap_policy(module, recurse, unwrapped_params: int, min_num_params: int = int(1e8)) -> bool: def _custom_auto_wrap_policy(module, recurse, unwrapped_params: int, min_num_params: int = int(1e8)) -> bool:
@ -84,7 +84,10 @@ def _custom_auto_wrap_policy(module, recurse, unwrapped_params: int, min_num_par
@pytest.mark.parametrize("manual_wrapping", [True, False]) @pytest.mark.parametrize("manual_wrapping", [True, False])
def test_fsdp_train_save_load(manual_wrapping, precision): def test_fsdp_train_save_load(manual_wrapping, precision):
"""Test FSDP training, saving and loading with different wrapping and precision settings.""" """Test FSDP training, saving and loading with different wrapping and precision settings."""
strategy = FSDPStrategy(auto_wrap_policy=_custom_auto_wrap_policy) strategy = FSDPStrategy(
auto_wrap_policy=_custom_auto_wrap_policy,
activation_checkpointing=[torch.nn.Linear],
)
lite = LightningLite(accelerator="cuda", strategy=strategy, devices=2, precision=precision) lite = LightningLite(accelerator="cuda", strategy=strategy, devices=2, precision=precision)
lite.launch() lite.launch()

View File

@ -1,8 +1,11 @@
import os import os
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from unittest import mock
from unittest.mock import ANY, Mock
import pytest import pytest
import torch import torch
import torch.nn as nn
from pytorch_lightning import Trainer from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.callbacks import ModelCheckpoint
@ -259,3 +262,47 @@ def test_invalid_parameters_in_optimizer(tmpdir):
model = NoFlatParametersModel() model = NoFlatParametersModel()
with pytest.raises(ValueError, match="The optimizer does not seem to reference any FSDP parameters"): with pytest.raises(ValueError, match="The optimizer does not seem to reference any FSDP parameters"):
trainer.fit(model) trainer.fit(model)
@RunIf(min_torch="1.12")
@mock.patch("pytorch_lightning.strategies.fully_sharded_native._TORCH_GREATER_EQUAL_1_13", False)
def test_fully_sharded_native_activation_checkpointing_support():
"""Test that we error out if activation checkpointing requires a newer PyTorch version."""
with pytest.raises(ValueError, match="Activation checkpointing requires torch >= 1.13.0"):
DDPFullyShardedNativeStrategy(activation_checkpointing=Mock())
@RunIf(min_torch="1.13")
def test_fully_sharded_native_activation_checkpointing():
"""Test that the FSDP strategy can apply activation checkpointing to the given layers."""
class Block1(nn.Linear):
pass
class Block2(nn.Linear):
pass
class Model(BoringModel):
def __init__(self):
super().__init__()
self.layer0 = nn.Sequential(Block1(4, 4), Block1(5, 5))
self.layer1 = Block2(2, 2)
self.layer2 = nn.Linear(3, 3)
strategy = DDPFullyShardedNativeStrategy(activation_checkpointing=Block1)
assert strategy._activation_checkpointing == [Block1]
strategy = DDPFullyShardedNativeStrategy(activation_checkpointing=[Block1, Block2])
assert strategy._activation_checkpointing == [Block1, Block2]
model = Model()
strategy._parallel_devices = [torch.device("cuda", 0)]
strategy._lightning_module = model
strategy._process_group = Mock()
with mock.patch(
"pytorch_lightning.strategies.fully_sharded_native.FullyShardedDataParallel"
) as fsdp_mock, mock.patch(
"torch.distributed.algorithms._checkpoint.checkpoint_wrapper.apply_activation_checkpointing"
) as ckpt_mock:
strategy._setup_model(model)
ckpt_mock.assert_called_with(fsdp_mock(), checkpoint_wrapper_fn=ANY, check_fn=ANY)