Activation checkpointing in FSDP without boilerplate (#15826)
* initial * input type * checkpointing * fsdp in pl * all_close Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
2992002beb
commit
05dbf48ad0
|
@ -428,13 +428,36 @@ You can customize the strategy configuration by adjusting the arguments of :clas
|
|||
|
||||
|
||||
native_fsdp = DDPFullyShardedNativeStrategy(cpu_offload=CPUOffload(offload_params=True))
|
||||
trainer = pl.Trainer(strategy=native_fsdp, accelerator="gpu", device=4)
|
||||
trainer = pl.Trainer(strategy=native_fsdp, accelerator="gpu", devices=4)
|
||||
|
||||
|
||||
Check out `this tutorial <https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html>`__ to learn more about the native support.
|
||||
|
||||
----
|
||||
|
||||
|
||||
Activation Checkpointing
|
||||
========================
|
||||
|
||||
Activation checkpointing reduces GPU memory usage by avoiding the storage of intermediate activation tensors in
|
||||
selected layers. The tradeoff is that computation cost for the backpropagation increases, as the dropped activations
|
||||
need to be recomputed.
|
||||
|
||||
Enable checkpointing on large layers (like Transformers) by providing the layer class/type to the strategy:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from pytorch_lightning.strategies import DDPFullyShardedNativeStrategy
|
||||
|
||||
fsdp = DDPFullyShardedNativeStrategy(
|
||||
activation_checkpointing=MyTransformerBlock, # or pass a list with multiple types
|
||||
)
|
||||
trainer = pl.Trainer(strategy=fsdp, accelerator="gpu", devices=4)
|
||||
|
||||
|
||||
----
|
||||
|
||||
|
||||
.. _deepspeed_advanced:
|
||||
|
||||
*********
|
||||
|
|
|
@ -11,9 +11,10 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import functools
|
||||
from contextlib import contextmanager
|
||||
from datetime import timedelta
|
||||
from typing import Any, Dict, Generator, List, Optional, Tuple, TYPE_CHECKING, Union
|
||||
from typing import Any, Dict, Generator, List, Optional, Tuple, Type, TYPE_CHECKING, Union
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
@ -35,7 +36,7 @@ from lightning_lite.utilities.distributed import (
|
|||
)
|
||||
from lightning_lite.utilities.distributed import group as _group
|
||||
from lightning_lite.utilities.distributed import ReduceOp
|
||||
from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_12
|
||||
from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_12, _TORCH_GREATER_EQUAL_1_13
|
||||
from lightning_lite.utilities.rank_zero import rank_zero_only
|
||||
from lightning_lite.utilities.seed import reset_seed
|
||||
|
||||
|
@ -78,6 +79,10 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
|
|||
computation overlapping. The pros and cons of each algorithm is explained in the class ``BackwardPrefetch``.
|
||||
mixed_precision: Mixed Precision config. By default, Lightning will enable FP16 if ``precision=16`` or BF16
|
||||
if ``precision=bf16`` unless a config is passed in. This is only available in PyTorch 1.12 and later.
|
||||
activation_checkpointing: A single layer or a list of layer classes for which you want to enable activation
|
||||
checkpointing. This is typically your transformer block (including attention + feed-forward).
|
||||
Enabling this can free up a significant amount of memory at the cost of speed since activations in
|
||||
these layers need to be recomputed during backpropagation.
|
||||
\**kwargs: Optional keywoard arguments passed to the FSDP context manager which will configure the FSDP class
|
||||
when wrapping modules.
|
||||
"""
|
||||
|
@ -94,6 +99,7 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
|
|||
cpu_offload: Optional["CPUOffload"] = None,
|
||||
backward_prefetch: Optional["BackwardPrefetch"] = None,
|
||||
mixed_precision: Optional["MixedPrecision"] = None,
|
||||
activation_checkpointing: Optional[Union[Type[Module], List[Type[Module]]]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
if not _TORCH_GREATER_EQUAL_1_12:
|
||||
|
@ -112,6 +118,13 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
|
|||
self._backward_sync_control = _FSDPBackwardSyncControl()
|
||||
self._ddp_kwargs = kwargs
|
||||
|
||||
if activation_checkpointing and not _TORCH_GREATER_EQUAL_1_13:
|
||||
raise ValueError("Activation checkpointing requires torch >= 1.13.0. HINT: `pip install -U torch`")
|
||||
activation_checkpointing = activation_checkpointing or []
|
||||
self._activation_checkpointing = (
|
||||
[activation_checkpointing] if not isinstance(activation_checkpointing, list) else activation_checkpointing
|
||||
)
|
||||
|
||||
self.cpu_offload = cpu_offload
|
||||
self.backward_prefetch = backward_prefetch
|
||||
self.mixed_precision = mixed_precision
|
||||
|
@ -175,13 +188,12 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
|
|||
:class:`~torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel` module."""
|
||||
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel
|
||||
|
||||
if (
|
||||
any(isinstance(mod, FullyShardedDataParallel) for mod in module.modules())
|
||||
and "auto_wrap_policy" in self._ddp_kwargs
|
||||
if "auto_wrap_policy" in self._ddp_kwargs and any(
|
||||
isinstance(mod, FullyShardedDataParallel) for mod in module.modules()
|
||||
):
|
||||
# If model is already wrapped, we need to avoid sending the `auto_wrap_policy`
|
||||
del self._ddp_kwargs["auto_wrap_policy"]
|
||||
return FullyShardedDataParallel(
|
||||
wrapped_module = FullyShardedDataParallel(
|
||||
module=module,
|
||||
cpu_offload=self.cpu_offload,
|
||||
backward_prefetch=self.backward_prefetch,
|
||||
|
@ -190,6 +202,12 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
|
|||
**self._ddp_kwargs,
|
||||
)
|
||||
|
||||
# activation checkpointing needs to be set up after wrapping the model
|
||||
if _TORCH_GREATER_EQUAL_1_13 and self._activation_checkpointing:
|
||||
_setup_activation_checkpointing(module=wrapped_module, layers=self._activation_checkpointing)
|
||||
|
||||
return wrapped_module
|
||||
|
||||
def setup_optimizer(self, optimizer: Optimizer) -> Optimizer:
|
||||
"""Set up an optimizer for a model wrapped with FSDP.
|
||||
|
||||
|
@ -291,6 +309,21 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
|
|||
rank_zero_only.rank = self.cluster_environment.global_rank()
|
||||
|
||||
|
||||
def _setup_activation_checkpointing(module: "FullyShardedDataParallel", layers: List[Type[Module]]) -> None:
|
||||
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
|
||||
apply_activation_checkpointing,
|
||||
checkpoint_wrapper,
|
||||
CheckpointImpl,
|
||||
)
|
||||
|
||||
check_fn = lambda submodule: isinstance(submodule, tuple(layers))
|
||||
wrapper = functools.partial(
|
||||
checkpoint_wrapper,
|
||||
checkpoint_impl=CheckpointImpl.NO_REENTRANT,
|
||||
)
|
||||
apply_activation_checkpointing(module, checkpoint_wrapper_fn=wrapper, check_fn=check_fn)
|
||||
|
||||
|
||||
class _FSDPBackwardSyncControl(_BackwardSyncControl):
|
||||
@contextmanager
|
||||
def no_backward_sync(self, module: Module) -> Generator:
|
||||
|
|
|
@ -29,6 +29,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
|
||||
- Added a warning when `self.log(..., logger=True)` is called without a configured logger ([#15814](https://github.com/Lightning-AI/lightning/pull/15814))
|
||||
|
||||
|
||||
- Added support for activation checkpointing for the `DDPFullyShardedNativeStrategy` strategy ([#15826](https://github.com/Lightning-AI/lightning/pull/15826))
|
||||
|
||||
|
||||
### Changed
|
||||
|
||||
- Drop PyTorch 1.9 support ([#15347](https://github.com/Lightning-AI/lightning/pull/15347))
|
||||
|
|
|
@ -13,14 +13,15 @@
|
|||
# limitations under the License.
|
||||
import contextlib
|
||||
import logging
|
||||
from typing import Any, Dict, Generator, List, Optional, Union
|
||||
from typing import Any, Dict, Generator, List, Optional, Type, Union
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.nn import Module
|
||||
|
||||
import pytorch_lightning as pl
|
||||
from lightning_lite.plugins import CheckpointIO, ClusterEnvironment
|
||||
from lightning_lite.strategies.fsdp import _optimizer_has_flat_params
|
||||
from lightning_lite.strategies.fsdp import _optimizer_has_flat_params, _setup_activation_checkpointing
|
||||
from lightning_lite.utilities.distributed import (
|
||||
_get_default_process_group_backend_for_device,
|
||||
_init_dist_connection,
|
||||
|
@ -38,7 +39,7 @@ from pytorch_lightning.strategies.parallel import ParallelStrategy
|
|||
from pytorch_lightning.strategies.strategy import TBroadcast
|
||||
from pytorch_lightning.trainer.states import TrainerFn
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12
|
||||
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12, _TORCH_GREATER_EQUAL_1_13
|
||||
from pytorch_lightning.utilities.model_helpers import is_overridden
|
||||
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only
|
||||
from pytorch_lightning.utilities.types import STEP_OUTPUT
|
||||
|
@ -100,6 +101,10 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy):
|
|||
Mixed Precision config. By default, Lightning will enable FP16 if ``precision=16``
|
||||
or BF16 if ``precision=bf16`` unless a config is passed in.
|
||||
This is only available in PyTorch 1.12 and later.
|
||||
activation_checkpointing: A single layer or a list of layer classes for which you want to enable activation
|
||||
checkpointing. This is typically your transformer block (including attention + feed-forward).
|
||||
Enabling this can free up a significant amount of memory at the cost of speed since activations in
|
||||
these layers need to be recomputed during backpropagation.
|
||||
\**kwargs: Passed to the FSDP context manager which will configure the FSDP class when wrapping modules.
|
||||
|
||||
"""
|
||||
|
@ -118,6 +123,7 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy):
|
|||
cpu_offload: Optional[CPUOffload] = None,
|
||||
backward_prefetch: Optional[BackwardPrefetch] = None,
|
||||
mixed_precision: Optional[MixedPrecision] = None,
|
||||
activation_checkpointing: Optional[Union[Type[Module], List[Type[Module]]]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
if not _TORCH_GREATER_EQUAL_1_12:
|
||||
|
@ -139,6 +145,12 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy):
|
|||
self.backward_prefetch = backward_prefetch
|
||||
self.mixed_precision = mixed_precision
|
||||
self._rank_0_will_call_children_scripts: bool = False
|
||||
if activation_checkpointing and not _TORCH_GREATER_EQUAL_1_13:
|
||||
raise ValueError("Activation checkpointing requires torch >= 1.13.0. HINT: `pip install -U torch`")
|
||||
activation_checkpointing = activation_checkpointing or []
|
||||
self._activation_checkpointing = (
|
||||
[activation_checkpointing] if not isinstance(activation_checkpointing, list) else activation_checkpointing
|
||||
)
|
||||
self.kwargs = kwargs
|
||||
|
||||
@property
|
||||
|
@ -209,15 +221,14 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy):
|
|||
:class:`~torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel` module."""
|
||||
# If model is already wrapped, we need to avoid sending the `auto_wrap_policy`
|
||||
assert self.lightning_module is not None
|
||||
if (
|
||||
any(isinstance(mod, FullyShardedDataParallel) for mod in self.lightning_module.modules())
|
||||
and "auto_wrap_policy" in self.kwargs
|
||||
if "auto_wrap_policy" in self.kwargs and any(
|
||||
isinstance(mod, FullyShardedDataParallel) for mod in self.lightning_module.modules()
|
||||
):
|
||||
del self.kwargs["auto_wrap_policy"]
|
||||
|
||||
log.detail(f"setting up FSDP model with device id: {self.root_device.index}, kwargs: {self.kwargs}")
|
||||
|
||||
return FullyShardedDataParallel(
|
||||
wrapped_module = FullyShardedDataParallel(
|
||||
module=model,
|
||||
process_group=self.process_group,
|
||||
cpu_offload=self.cpu_offload,
|
||||
|
@ -227,6 +238,12 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy):
|
|||
**self.kwargs,
|
||||
)
|
||||
|
||||
# activation checkpointing needs to be set up after wrapping the model
|
||||
if _TORCH_GREATER_EQUAL_1_13 and self._activation_checkpointing:
|
||||
_setup_activation_checkpointing(module=wrapped_module, layers=self._activation_checkpointing)
|
||||
|
||||
return wrapped_module
|
||||
|
||||
def setup(self, trainer: "pl.Trainer") -> None:
|
||||
assert self.accelerator is not None
|
||||
self.accelerator.setup(trainer)
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
# limitations under the License.
|
||||
|
||||
from unittest import mock
|
||||
from unittest.mock import MagicMock, Mock
|
||||
from unittest.mock import ANY, MagicMock, Mock
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
@ -77,3 +77,44 @@ def test_fsdp_no_backward_sync():
|
|||
pass
|
||||
|
||||
module.no_sync.assert_called_once()
|
||||
|
||||
|
||||
@RunIf(min_torch="1.12")
|
||||
@mock.patch("lightning_lite.strategies.fsdp._TORCH_GREATER_EQUAL_1_13", False)
|
||||
def test_fsdp_activation_checkpointing_support():
|
||||
"""Test that we error out if activation checkpointing requires a newer PyTorch version."""
|
||||
with pytest.raises(ValueError, match="Activation checkpointing requires torch >= 1.13.0"):
|
||||
FSDPStrategy(activation_checkpointing=Mock())
|
||||
|
||||
|
||||
@RunIf(min_torch="1.13")
|
||||
def test_fsdp_activation_checkpointing():
|
||||
"""Test that the FSDP strategy can apply activation checkpointing to the given layers."""
|
||||
|
||||
class Block1(nn.Linear):
|
||||
pass
|
||||
|
||||
class Block2(nn.Linear):
|
||||
pass
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.layer0 = nn.Sequential(Block1(4, 4), Block1(5, 5))
|
||||
self.layer1 = Block2(2, 2)
|
||||
self.layer2 = nn.Linear(3, 3)
|
||||
|
||||
strategy = FSDPStrategy(activation_checkpointing=Block1)
|
||||
assert strategy._activation_checkpointing == [Block1]
|
||||
|
||||
strategy = FSDPStrategy(activation_checkpointing=[Block1, Block2])
|
||||
assert strategy._activation_checkpointing == [Block1, Block2]
|
||||
|
||||
strategy._parallel_devices = [torch.device("cuda", 0)]
|
||||
with mock.patch(
|
||||
"torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel"
|
||||
) as fsdp_mock, mock.patch(
|
||||
"torch.distributed.algorithms._checkpoint.checkpoint_wrapper.apply_activation_checkpointing"
|
||||
) as ckpt_mock:
|
||||
strategy.setup_module(Model())
|
||||
ckpt_mock.assert_called_with(fsdp_mock(), checkpoint_wrapper_fn=ANY, check_fn=ANY)
|
||||
|
|
|
@ -72,7 +72,7 @@ def _assert_save_equality(lite, model, ckpt_path):
|
|||
|
||||
# model parameters are identical after loading
|
||||
for current_param, loaded_param in zip(current_state_dict.values(), loaded_model.state_dict().values()):
|
||||
assert torch.equal(current_param.float().cpu(), loaded_param.cpu())
|
||||
assert torch.allclose(current_param.float().cpu(), loaded_param.cpu())
|
||||
|
||||
|
||||
def _custom_auto_wrap_policy(module, recurse, unwrapped_params: int, min_num_params: int = int(1e8)) -> bool:
|
||||
|
@ -84,7 +84,10 @@ def _custom_auto_wrap_policy(module, recurse, unwrapped_params: int, min_num_par
|
|||
@pytest.mark.parametrize("manual_wrapping", [True, False])
|
||||
def test_fsdp_train_save_load(manual_wrapping, precision):
|
||||
"""Test FSDP training, saving and loading with different wrapping and precision settings."""
|
||||
strategy = FSDPStrategy(auto_wrap_policy=_custom_auto_wrap_policy)
|
||||
strategy = FSDPStrategy(
|
||||
auto_wrap_policy=_custom_auto_wrap_policy,
|
||||
activation_checkpointing=[torch.nn.Linear],
|
||||
)
|
||||
lite = LightningLite(accelerator="cuda", strategy=strategy, devices=2, precision=precision)
|
||||
lite.launch()
|
||||
|
||||
|
|
|
@ -1,8 +1,11 @@
|
|||
import os
|
||||
from typing import Any, Dict, Optional
|
||||
from unittest import mock
|
||||
from unittest.mock import ANY, Mock
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.callbacks import ModelCheckpoint
|
||||
|
@ -259,3 +262,47 @@ def test_invalid_parameters_in_optimizer(tmpdir):
|
|||
model = NoFlatParametersModel()
|
||||
with pytest.raises(ValueError, match="The optimizer does not seem to reference any FSDP parameters"):
|
||||
trainer.fit(model)
|
||||
|
||||
|
||||
@RunIf(min_torch="1.12")
|
||||
@mock.patch("pytorch_lightning.strategies.fully_sharded_native._TORCH_GREATER_EQUAL_1_13", False)
|
||||
def test_fully_sharded_native_activation_checkpointing_support():
|
||||
"""Test that we error out if activation checkpointing requires a newer PyTorch version."""
|
||||
with pytest.raises(ValueError, match="Activation checkpointing requires torch >= 1.13.0"):
|
||||
DDPFullyShardedNativeStrategy(activation_checkpointing=Mock())
|
||||
|
||||
|
||||
@RunIf(min_torch="1.13")
|
||||
def test_fully_sharded_native_activation_checkpointing():
|
||||
"""Test that the FSDP strategy can apply activation checkpointing to the given layers."""
|
||||
|
||||
class Block1(nn.Linear):
|
||||
pass
|
||||
|
||||
class Block2(nn.Linear):
|
||||
pass
|
||||
|
||||
class Model(BoringModel):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.layer0 = nn.Sequential(Block1(4, 4), Block1(5, 5))
|
||||
self.layer1 = Block2(2, 2)
|
||||
self.layer2 = nn.Linear(3, 3)
|
||||
|
||||
strategy = DDPFullyShardedNativeStrategy(activation_checkpointing=Block1)
|
||||
assert strategy._activation_checkpointing == [Block1]
|
||||
|
||||
strategy = DDPFullyShardedNativeStrategy(activation_checkpointing=[Block1, Block2])
|
||||
assert strategy._activation_checkpointing == [Block1, Block2]
|
||||
|
||||
model = Model()
|
||||
strategy._parallel_devices = [torch.device("cuda", 0)]
|
||||
strategy._lightning_module = model
|
||||
strategy._process_group = Mock()
|
||||
with mock.patch(
|
||||
"pytorch_lightning.strategies.fully_sharded_native.FullyShardedDataParallel"
|
||||
) as fsdp_mock, mock.patch(
|
||||
"torch.distributed.algorithms._checkpoint.checkpoint_wrapper.apply_activation_checkpointing"
|
||||
) as ckpt_mock:
|
||||
strategy._setup_model(model)
|
||||
ckpt_mock.assert_called_with(fsdp_mock(), checkpoint_wrapper_fn=ANY, check_fn=ANY)
|
||||
|
|
Loading…
Reference in New Issue