Activation checkpointing in FSDP without boilerplate (#15826)
* initial
* input type
* checkpointing
* fsdp in pl
* all_close

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

parent 2992002beb · commit 05dbf48ad0
@@ -428,13 +428,36 @@ You can customize the strategy configuration by adjusting the arguments of :clas

     native_fsdp = DDPFullyShardedNativeStrategy(cpu_offload=CPUOffload(offload_params=True))
-    trainer = pl.Trainer(strategy=native_fsdp, accelerator="gpu", device=4)
+    trainer = pl.Trainer(strategy=native_fsdp, accelerator="gpu", devices=4)

 Check out `this tutorial <https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html>`__ to learn more about the native support.

 ----

+Activation Checkpointing
+========================
+
+Activation checkpointing reduces GPU memory usage by avoiding the storage of intermediate activation tensors in
+selected layers. The tradeoff is that computation cost for the backpropagation increases, as the dropped activations
+need to be recomputed.
+
+Enable checkpointing on large layers (like Transformers) by providing the layer class/type to the strategy:
+
+.. code-block:: python
+
+    from pytorch_lightning.strategies import DDPFullyShardedNativeStrategy
+
+    fsdp = DDPFullyShardedNativeStrategy(
+        activation_checkpointing=MyTransformerBlock,  # or pass a list with multiple types
+    )
+    trainer = pl.Trainer(strategy=fsdp, accelerator="gpu", devices=4)
+
+----
+
 .. _deepspeed_advanced:

 *********
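The documentation snippet above passes a `MyTransformerBlock` class that it never defines. As a rough illustration (not part of this commit), any `torch.nn.Module` subclass works; a minimal transformer-style block might look like the following sketch, where all names and sizes are made up:

import torch.nn as nn


class MyTransformerBlock(nn.Module):
    """Hypothetical user-defined block; only its class is handed to `activation_checkpointing`."""

    def __init__(self, dim: int = 512, n_heads: int = 8) -> None:
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, n_heads, batch_first=True)
        self.ff = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, x):
        # attention + feed-forward: the activations of this whole block are dropped after the
        # forward pass and recomputed during backward when activation checkpointing is enabled
        attn_out, _ = self.attn(x, x, x, need_weights=False)
        x = self.norm1(x + attn_out)
        return self.norm2(x + self.ff(x))

With this in scope, the documentation example above runs as written: FSDP shards the model while the strategy recomputes each block's activations in the backward pass instead of storing them.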
@@ -11,9 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import functools
 from contextlib import contextmanager
 from datetime import timedelta
-from typing import Any, Dict, Generator, List, Optional, Tuple, TYPE_CHECKING, Union
+from typing import Any, Dict, Generator, List, Optional, Tuple, Type, TYPE_CHECKING, Union

 import torch
 from torch import Tensor
@@ -35,7 +36,7 @@ from lightning_lite.utilities.distributed import (
 )
 from lightning_lite.utilities.distributed import group as _group
 from lightning_lite.utilities.distributed import ReduceOp
-from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_12
+from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_12, _TORCH_GREATER_EQUAL_1_13
 from lightning_lite.utilities.rank_zero import rank_zero_only
 from lightning_lite.utilities.seed import reset_seed

@@ -78,6 +79,10 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
             computation overlapping. The pros and cons of each algorithm is explained in the class ``BackwardPrefetch``.
         mixed_precision: Mixed Precision config. By default, Lightning will enable FP16 if ``precision=16`` or BF16
             if ``precision=bf16`` unless a config is passed in. This is only available in PyTorch 1.12 and later.
+        activation_checkpointing: A single layer or a list of layer classes for which you want to enable activation
+            checkpointing. This is typically your transformer block (including attention + feed-forward).
+            Enabling this can free up a significant amount of memory at the cost of speed since activations in
+            these layers need to be recomputed during backpropagation.
         \**kwargs: Optional keywoard arguments passed to the FSDP context manager which will configure the FSDP class
             when wrapping modules.
     """
@@ -94,6 +99,7 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
         cpu_offload: Optional["CPUOffload"] = None,
         backward_prefetch: Optional["BackwardPrefetch"] = None,
         mixed_precision: Optional["MixedPrecision"] = None,
+        activation_checkpointing: Optional[Union[Type[Module], List[Type[Module]]]] = None,
         **kwargs: Any,
     ) -> None:
         if not _TORCH_GREATER_EQUAL_1_12:
@@ -112,6 +118,13 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
         self._backward_sync_control = _FSDPBackwardSyncControl()
         self._ddp_kwargs = kwargs

+        if activation_checkpointing and not _TORCH_GREATER_EQUAL_1_13:
+            raise ValueError("Activation checkpointing requires torch >= 1.13.0. HINT: `pip install -U torch`")
+        activation_checkpointing = activation_checkpointing or []
+        self._activation_checkpointing = (
+            [activation_checkpointing] if not isinstance(activation_checkpointing, list) else activation_checkpointing
+        )
+
         self.cpu_offload = cpu_offload
         self.backward_prefetch = backward_prefetch
         self.mixed_precision = mixed_precision
@@ -175,13 +188,12 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
         :class:`~torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel` module."""
         from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel

-        if (
-            any(isinstance(mod, FullyShardedDataParallel) for mod in module.modules())
-            and "auto_wrap_policy" in self._ddp_kwargs
+        if "auto_wrap_policy" in self._ddp_kwargs and any(
+            isinstance(mod, FullyShardedDataParallel) for mod in module.modules()
         ):
             # If model is already wrapped, we need to avoid sending the `auto_wrap_policy`
             del self._ddp_kwargs["auto_wrap_policy"]
-        return FullyShardedDataParallel(
+        wrapped_module = FullyShardedDataParallel(
             module=module,
             cpu_offload=self.cpu_offload,
             backward_prefetch=self.backward_prefetch,
@@ -190,6 +202,12 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
             **self._ddp_kwargs,
         )

+        # activation checkpointing needs to be set up after wrapping the model
+        if _TORCH_GREATER_EQUAL_1_13 and self._activation_checkpointing:
+            _setup_activation_checkpointing(module=wrapped_module, layers=self._activation_checkpointing)
+
+        return wrapped_module
+
     def setup_optimizer(self, optimizer: Optimizer) -> Optimizer:
         """Set up an optimizer for a model wrapped with FSDP.

@@ -291,6 +309,21 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
         rank_zero_only.rank = self.cluster_environment.global_rank()


+def _setup_activation_checkpointing(module: "FullyShardedDataParallel", layers: List[Type[Module]]) -> None:
+    from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
+        apply_activation_checkpointing,
+        checkpoint_wrapper,
+        CheckpointImpl,
+    )
+
+    check_fn = lambda submodule: isinstance(submodule, tuple(layers))
+    wrapper = functools.partial(
+        checkpoint_wrapper,
+        checkpoint_impl=CheckpointImpl.NO_REENTRANT,
+    )
+    apply_activation_checkpointing(module, checkpoint_wrapper_fn=wrapper, check_fn=check_fn)
+
+
 class _FSDPBackwardSyncControl(_BackwardSyncControl):
     @contextmanager
     def no_backward_sync(self, module: Module) -> Generator:
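The hunks above touch `lightning_lite.strategies.fsdp`, and the new `_setup_activation_checkpointing` helper is a thin layer over PyTorch's own checkpoint-wrapper utilities. To make it concrete, here is a standalone sketch of the same API calls outside Lightning, assuming torch >= 1.13; the toy `Block` class and model are made up, only the calls mirror the helper:

import functools

import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
    checkpoint_wrapper,
    CheckpointImpl,
)


class Block(nn.Linear):
    """Toy stand-in for a transformer block."""


model = nn.Sequential(Block(4, 4), nn.ReLU(), Block(4, 4), nn.Linear(4, 2))

# Wrap every submodule whose class is in the target set, exactly like the helper does with
# the classes the user passed to `activation_checkpointing`.
wrapper = functools.partial(checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT)
apply_activation_checkpointing(
    model,
    checkpoint_wrapper_fn=wrapper,
    check_fn=lambda submodule: isinstance(submodule, (Block,)),
)
# Matching submodules are now replaced in place by checkpoint wrappers, so their activations
# are recomputed during the backward pass instead of being stored.

In the strategy this runs on the already FSDP-wrapped module, which is why the diff only sets it up after `FullyShardedDataParallel(...)` returns.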
@@ -29,6 +29,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Added a warning when `self.log(..., logger=True)` is called without a configured logger ([#15814](https://github.com/Lightning-AI/lightning/pull/15814))

+
+- Added support for activation checkpointing for the `DDPFullyShardedNativeStrategy` strategy ([#15826](https://github.com/Lightning-AI/lightning/pull/15826))
+
+
 ### Changed

 - Drop PyTorch 1.9 support ([#15347](https://github.com/Lightning-AI/lightning/pull/15347))
@@ -13,14 +13,15 @@
 # limitations under the License.
 import contextlib
 import logging
-from typing import Any, Dict, Generator, List, Optional, Union
+from typing import Any, Dict, Generator, List, Optional, Type, Union

 import torch
 from torch import Tensor
+from torch.nn import Module

 import pytorch_lightning as pl
 from lightning_lite.plugins import CheckpointIO, ClusterEnvironment
-from lightning_lite.strategies.fsdp import _optimizer_has_flat_params
+from lightning_lite.strategies.fsdp import _optimizer_has_flat_params, _setup_activation_checkpointing
 from lightning_lite.utilities.distributed import (
     _get_default_process_group_backend_for_device,
     _init_dist_connection,
@@ -38,7 +39,7 @@ from pytorch_lightning.strategies.parallel import ParallelStrategy
 from pytorch_lightning.strategies.strategy import TBroadcast
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12
+from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12, _TORCH_GREATER_EQUAL_1_13
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only
 from pytorch_lightning.utilities.types import STEP_OUTPUT
@@ -100,6 +101,10 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy):
             Mixed Precision config. By default, Lightning will enable FP16 if ``precision=16``
             or BF16 if ``precision=bf16`` unless a config is passed in.
             This is only available in PyTorch 1.12 and later.
+        activation_checkpointing: A single layer or a list of layer classes for which you want to enable activation
+            checkpointing. This is typically your transformer block (including attention + feed-forward).
+            Enabling this can free up a significant amount of memory at the cost of speed since activations in
+            these layers need to be recomputed during backpropagation.
         \**kwargs: Passed to the FSDP context manager which will configure the FSDP class when wrapping modules.

     """
@@ -118,6 +123,7 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy):
         cpu_offload: Optional[CPUOffload] = None,
         backward_prefetch: Optional[BackwardPrefetch] = None,
         mixed_precision: Optional[MixedPrecision] = None,
+        activation_checkpointing: Optional[Union[Type[Module], List[Type[Module]]]] = None,
         **kwargs: Any,
     ) -> None:
         if not _TORCH_GREATER_EQUAL_1_12:
@@ -139,6 +145,12 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy):
         self.backward_prefetch = backward_prefetch
         self.mixed_precision = mixed_precision
         self._rank_0_will_call_children_scripts: bool = False
+        if activation_checkpointing and not _TORCH_GREATER_EQUAL_1_13:
+            raise ValueError("Activation checkpointing requires torch >= 1.13.0. HINT: `pip install -U torch`")
+        activation_checkpointing = activation_checkpointing or []
+        self._activation_checkpointing = (
+            [activation_checkpointing] if not isinstance(activation_checkpointing, list) else activation_checkpointing
+        )
         self.kwargs = kwargs

     @property
@@ -209,15 +221,14 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy):
         :class:`~torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel` module."""
         # If model is already wrapped, we need to avoid sending the `auto_wrap_policy`
         assert self.lightning_module is not None
-        if (
-            any(isinstance(mod, FullyShardedDataParallel) for mod in self.lightning_module.modules())
-            and "auto_wrap_policy" in self.kwargs
+        if "auto_wrap_policy" in self.kwargs and any(
+            isinstance(mod, FullyShardedDataParallel) for mod in self.lightning_module.modules()
         ):
             del self.kwargs["auto_wrap_policy"]

         log.detail(f"setting up FSDP model with device id: {self.root_device.index}, kwargs: {self.kwargs}")

-        return FullyShardedDataParallel(
+        wrapped_module = FullyShardedDataParallel(
             module=model,
             process_group=self.process_group,
             cpu_offload=self.cpu_offload,
@@ -227,6 +238,12 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy):
             **self.kwargs,
         )

+        # activation checkpointing needs to be set up after wrapping the model
+        if _TORCH_GREATER_EQUAL_1_13 and self._activation_checkpointing:
+            _setup_activation_checkpointing(module=wrapped_module, layers=self._activation_checkpointing)
+
+        return wrapped_module
+
     def setup(self, trainer: "pl.Trainer") -> None:
         assert self.accelerator is not None
         self.accelerator.setup(trainer)
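These hunks make the same change in `pytorch_lightning.strategies.fully_sharded_native`, and the `auto_wrap_policy` handling above hints at how the two knobs are typically combined: shard at the block boundary and checkpoint the same blocks. A hedged sketch, not taken from this commit, assuming the hypothetical `MyTransformerBlock` defined earlier:

import functools

import pytorch_lightning as pl
from pytorch_lightning.strategies import DDPFullyShardedNativeStrategy
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

# Wrap each transformer block as its own FSDP unit and checkpoint its activations.
policy = functools.partial(transformer_auto_wrap_policy, transformer_layer_cls={MyTransformerBlock})
fsdp = DDPFullyShardedNativeStrategy(
    auto_wrap_policy=policy,  # forwarded to FSDP through **kwargs, as the handling above shows
    activation_checkpointing=MyTransformerBlock,
)
trainer = pl.Trainer(strategy=fsdp, accelerator="gpu", devices=4)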
@@ -13,7 +13,7 @@
 # limitations under the License.

 from unittest import mock
-from unittest.mock import MagicMock, Mock
+from unittest.mock import ANY, MagicMock, Mock

 import pytest
 import torch
@@ -77,3 +77,44 @@ def test_fsdp_no_backward_sync():
             pass

     module.no_sync.assert_called_once()
+
+
+@RunIf(min_torch="1.12")
+@mock.patch("lightning_lite.strategies.fsdp._TORCH_GREATER_EQUAL_1_13", False)
+def test_fsdp_activation_checkpointing_support():
+    """Test that we error out if activation checkpointing requires a newer PyTorch version."""
+    with pytest.raises(ValueError, match="Activation checkpointing requires torch >= 1.13.0"):
+        FSDPStrategy(activation_checkpointing=Mock())
+
+
+@RunIf(min_torch="1.13")
+def test_fsdp_activation_checkpointing():
+    """Test that the FSDP strategy can apply activation checkpointing to the given layers."""

+    class Block1(nn.Linear):
+        pass
+
+    class Block2(nn.Linear):
+        pass
+
+    class Model(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.layer0 = nn.Sequential(Block1(4, 4), Block1(5, 5))
+            self.layer1 = Block2(2, 2)
+            self.layer2 = nn.Linear(3, 3)
+
+    strategy = FSDPStrategy(activation_checkpointing=Block1)
+    assert strategy._activation_checkpointing == [Block1]
+
+    strategy = FSDPStrategy(activation_checkpointing=[Block1, Block2])
+    assert strategy._activation_checkpointing == [Block1, Block2]
+
+    strategy._parallel_devices = [torch.device("cuda", 0)]
+    with mock.patch(
+        "torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel"
+    ) as fsdp_mock, mock.patch(
+        "torch.distributed.algorithms._checkpoint.checkpoint_wrapper.apply_activation_checkpointing"
+    ) as ckpt_mock:
+        strategy.setup_module(Model())
+    ckpt_mock.assert_called_with(fsdp_mock(), checkpoint_wrapper_fn=ANY, check_fn=ANY)
@@ -72,7 +72,7 @@ def _assert_save_equality(lite, model, ckpt_path):

     # model parameters are identical after loading
     for current_param, loaded_param in zip(current_state_dict.values(), loaded_model.state_dict().values()):
-        assert torch.equal(current_param.float().cpu(), loaded_param.cpu())
+        assert torch.allclose(current_param.float().cpu(), loaded_param.cpu())


 def _custom_auto_wrap_policy(module, recurse, unwrapped_params: int, min_num_params: int = int(1e8)) -> bool:
@@ -84,7 +84,10 @@ def _custom_auto_wrap_policy(module, recurse, unwrapped_params: int, min_num_par
 @pytest.mark.parametrize("manual_wrapping", [True, False])
 def test_fsdp_train_save_load(manual_wrapping, precision):
     """Test FSDP training, saving and loading with different wrapping and precision settings."""
-    strategy = FSDPStrategy(auto_wrap_policy=_custom_auto_wrap_policy)
+    strategy = FSDPStrategy(
+        auto_wrap_policy=_custom_auto_wrap_policy,
+        activation_checkpointing=[torch.nn.Linear],
+    )
     lite = LightningLite(accelerator="cuda", strategy=strategy, devices=2, precision=precision)
     lite.launch()

@@ -1,8 +1,11 @@
 import os
 from typing import Any, Dict, Optional
+from unittest import mock
+from unittest.mock import ANY, Mock

 import pytest
 import torch
+import torch.nn as nn

 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
@@ -259,3 +262,47 @@ def test_invalid_parameters_in_optimizer(tmpdir):
     model = NoFlatParametersModel()
     with pytest.raises(ValueError, match="The optimizer does not seem to reference any FSDP parameters"):
         trainer.fit(model)
+
+
+@RunIf(min_torch="1.12")
+@mock.patch("pytorch_lightning.strategies.fully_sharded_native._TORCH_GREATER_EQUAL_1_13", False)
+def test_fully_sharded_native_activation_checkpointing_support():
+    """Test that we error out if activation checkpointing requires a newer PyTorch version."""
+    with pytest.raises(ValueError, match="Activation checkpointing requires torch >= 1.13.0"):
+        DDPFullyShardedNativeStrategy(activation_checkpointing=Mock())
+
+
+@RunIf(min_torch="1.13")
+def test_fully_sharded_native_activation_checkpointing():
+    """Test that the FSDP strategy can apply activation checkpointing to the given layers."""
+
+    class Block1(nn.Linear):
+        pass
+
+    class Block2(nn.Linear):
+        pass
+
+    class Model(BoringModel):
+        def __init__(self):
+            super().__init__()
+            self.layer0 = nn.Sequential(Block1(4, 4), Block1(5, 5))
+            self.layer1 = Block2(2, 2)
+            self.layer2 = nn.Linear(3, 3)
+
+    strategy = DDPFullyShardedNativeStrategy(activation_checkpointing=Block1)
+    assert strategy._activation_checkpointing == [Block1]
+
+    strategy = DDPFullyShardedNativeStrategy(activation_checkpointing=[Block1, Block2])
+    assert strategy._activation_checkpointing == [Block1, Block2]
+
+    model = Model()
+    strategy._parallel_devices = [torch.device("cuda", 0)]
+    strategy._lightning_module = model
+    strategy._process_group = Mock()
+    with mock.patch(
+        "pytorch_lightning.strategies.fully_sharded_native.FullyShardedDataParallel"
+    ) as fsdp_mock, mock.patch(
+        "torch.distributed.algorithms._checkpoint.checkpoint_wrapper.apply_activation_checkpointing"
+    ) as ckpt_mock:
+        strategy._setup_model(model)
+    ckpt_mock.assert_called_with(fsdp_mock(), checkpoint_wrapper_fn=ANY, check_fn=ANY)