Simplify enabling CPU offload in FSDP (#15832)
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
parent 852089e056
commit 2debd1c6b6
@@ -424,10 +424,9 @@ You can customize the strategy configuration by adjusting the arguments of :clas
     from pytorch_lightning import Trainer
     from pytorch_lightning.strategies import DDPFullyShardedNativeStrategy
-    from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload

-    native_fsdp = DDPFullyShardedNativeStrategy(cpu_offload=CPUOffload(offload_params=True))
+    native_fsdp = DDPFullyShardedNativeStrategy(cpu_offload=True)
     trainer = pl.Trainer(strategy=native_fsdp, accelerator="gpu", devices=4)

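For comparison, a minimal sketch of the two equivalent ways to enable CPU offloading after this change; it assumes a GPU machine with PyTorch >= 1.12 and `torch.distributed` available, and the accelerator/device values are illustrative.

    from pytorch_lightning import Trainer
    from pytorch_lightning.strategies import DDPFullyShardedNativeStrategy
    from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload

    # New shorthand: a plain bool turns parameter (and gradient) offloading on.
    strategy = DDPFullyShardedNativeStrategy(cpu_offload=True)

    # Passing the config object still works and is equivalent to the bool form.
    strategy = DDPFullyShardedNativeStrategy(cpu_offload=CPUOffload(offload_params=True))

    trainer = Trainer(strategy=strategy, accelerator="gpu", devices=4)
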
@@ -69,11 +69,10 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
     `this tutorial <https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html>`__ for more information.

     Arguments:
-        cpu_offload: CPU offloading config. Currently, only parameter and gradient CPU offload is supported. It
-            can be enabled via passing in ``cpu_offload=CPUOffload(offload_params=True)``. Note that this currently
+        cpu_offload: Enable offloading parameters and gradients to CPU to save GPU memory at the cost of speed.
+            You can also pass a config: ``cpu_offload=CPUOffload(offload_params=True)``. Note that this currently
             implicitly enables gradient offloading to CPU in order for parameters and gradients to be on same device
-            to work with the optimizer. This API is subject to change. Default is ``None`` in which case there
-            will be no offloading.
+            to work with the optimizer. This API is subject to change. Default: no offloading.
         backward_prefetch: This is an experimental feature that is subject to change in the near future. It allows
             users to enable two different backward prefetching algorithms to help backward communication and
             computation overlapping. The pros and cons of each algorithm is explained in the class ``BackwardPrefetch``.

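The same shorthand applies to the Lite strategy whose docstring is updated above; a short sketch, assuming an environment with `lightning_lite` and PyTorch >= 1.12 installed:

    from lightning_lite.strategies.fsdp import FSDPStrategy
    from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload

    # Bool shorthand: offload parameters (and, implicitly, gradients) to CPU.
    strategy = FSDPStrategy(cpu_offload=True)

    # A CPUOffload config is still accepted and used as-is.
    strategy = FSDPStrategy(cpu_offload=CPUOffload(offload_params=True))

    # Default (None): no offloading.
    strategy = FSDPStrategy()
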
@@ -96,7 +95,7 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
         precision: Optional[Precision] = None,
         process_group_backend: Optional[str] = None,
         timeout: Optional[timedelta] = default_pg_timeout,
-        cpu_offload: Optional["CPUOffload"] = None,
+        cpu_offload: Union[bool, "CPUOffload", None] = None,
         backward_prefetch: Optional["BackwardPrefetch"] = None,
         mixed_precision: Optional["MixedPrecision"] = None,
         activation_checkpointing: Optional[Union[Type[Module], List[Type[Module]]]] = None,

@@ -125,7 +124,7 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
             [activation_checkpointing] if not isinstance(activation_checkpointing, list) else activation_checkpointing
         )

-        self.cpu_offload = cpu_offload
+        self.cpu_offload = _init_cpu_offload(cpu_offload)
         self.backward_prefetch = backward_prefetch
         self.mixed_precision = mixed_precision

@@ -276,7 +275,6 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
     def register_strategies(cls, strategy_registry: Dict) -> None:
         if not _TORCH_GREATER_EQUAL_1_12 or not torch.distributed.is_available():
             return
-        from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload

         strategy_registry.register(
             "fsdp",

@@ -287,7 +285,7 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
             "fsdp_full_shard_offload",
             cls,
             description="Native FSDP with Full Sharding and CPU Offloading",
-            cpu_offload=CPUOffload(offload_params=True),
+            cpu_offload=True,
         )

     def _setup_distributed(self) -> None:

@@ -341,6 +339,12 @@ class _FSDPBackwardSyncControl(_BackwardSyncControl):
             yield


+def _init_cpu_offload(cpu_offload: Optional[Union[bool, "CPUOffload"]]) -> "CPUOffload":
+    from torch.distributed.fsdp import CPUOffload
+
+    return cpu_offload if isinstance(cpu_offload, CPUOffload) else CPUOffload(offload_params=bool(cpu_offload))
+
+
 def _optimizer_has_flat_params(optimizer: Optimizer) -> bool:
     from torch.distributed.fsdp import FlatParameter

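To make the normalization concrete, here is a sketch of what `_init_cpu_offload` returns for each accepted input (assumes PyTorch >= 1.12; the helper is private, so importing it directly is shown only for illustration):

    from torch.distributed.fsdp import CPUOffload

    from lightning_lite.strategies.fsdp import _init_cpu_offload

    assert _init_cpu_offload(True) == CPUOffload(offload_params=True)    # bool shorthand enables offloading
    assert _init_cpu_offload(False) == CPUOffload(offload_params=False)  # explicit opt-out
    assert _init_cpu_offload(None) == CPUOffload(offload_params=False)   # default: no offloading
    config = CPUOffload(offload_params=True)
    assert _init_cpu_offload(config) is config                           # an existing config passes through unchanged
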
@@ -33,6 +33,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added support for activation checkpointing for the `DDPFullyShardedNativeStrategy` strategy ([#15826](https://github.com/Lightning-AI/lightning/pull/15826))

+
+- Added the option to set `DDPFullyShardedNativeStrategy(cpu_offload=True|False)` via bool instead of needing to pass a configuration object ([#15832](https://github.com/Lightning-AI/lightning/pull/15832))
+
+
 ### Changed

 - Drop PyTorch 1.9 support ([#15347](https://github.com/Lightning-AI/lightning/pull/15347))

@@ -21,7 +21,11 @@ from torch.nn import Module

 import pytorch_lightning as pl
 from lightning_lite.plugins import CheckpointIO, ClusterEnvironment
-from lightning_lite.strategies.fsdp import _optimizer_has_flat_params, _setup_activation_checkpointing
+from lightning_lite.strategies.fsdp import (
+    _init_cpu_offload,
+    _optimizer_has_flat_params,
+    _setup_activation_checkpointing,
+)
 from lightning_lite.utilities.distributed import (
     _get_default_process_group_backend_for_device,
     _init_dist_connection,

@@ -84,14 +88,10 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy):
     `this tutorial <https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html>`__ for more information.

     Arguments:
-        cpu_offload:
-            CPU offloading config. Currently, only parameter and gradient CPU
-            offload is supported. It can be enabled via passing in
-            ``cpu_offload=CPUOffload(offload_params=True)``. Note that this
-            currently implicitly enables gradient offloading to CPU in order for
-            params and grads to be on same device to work with optimizer. This
-            API is subject to change. Default is ``None`` in which case there
-            will be no offloading.
+        cpu_offload: Enable offloading parameters and gradients to CPU to save GPU memory at the cost of speed.
+            You can also pass a config: ``cpu_offload=CPUOffload(offload_params=True)``. Note that this currently
+            implicitly enables gradient offloading to CPU in order for parameters and gradients to be on same device
+            to work with the optimizer. This API is subject to change. Default: no offloading.
         backward_prefetch:
             This is an experimental feature that is subject to change in the
             the near future. It allows users to enable two different backward_prefetch

@@ -120,7 +120,7 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy):
         checkpoint_io: Optional[CheckpointIO] = None,
         precision_plugin: Optional[PrecisionPlugin] = None,
         process_group_backend: Optional[str] = None,
-        cpu_offload: Optional[CPUOffload] = None,
+        cpu_offload: Union[bool, "CPUOffload", None] = None,
         backward_prefetch: Optional[BackwardPrefetch] = None,
         mixed_precision: Optional[MixedPrecision] = None,
         activation_checkpointing: Optional[Union[Type[Module], List[Type[Module]]]] = None,

@@ -141,7 +141,7 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy):
         self._process_group = None
         self.num_nodes = 1
         self._process_group_backend = process_group_backend
-        self.cpu_offload = cpu_offload
+        self.cpu_offload = _init_cpu_offload(cpu_offload)
         self.backward_prefetch = backward_prefetch
         self.mixed_precision = mixed_precision
         self._rank_0_will_call_children_scripts: bool = False

@@ -403,6 +403,6 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy):
             "fsdp_native_full_shard_offload",
             cls,
             description="Native FSDP with Full Sharding and CPU Offloading",
-            cpu_offload=CPUOffload(offload_params=True),
+            cpu_offload=True,
         )
         cls._registered_strategies.append("fsdp_native_full_shard_offload")

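Since the registered alias now simply passes `cpu_offload=True`, selecting it by name remains a one-liner; a minimal sketch (the accelerator/device values are illustrative):

    from pytorch_lightning import Trainer

    # Equivalent to Trainer(strategy=DDPFullyShardedNativeStrategy(cpu_offload=True), ...)
    trainer = Trainer(strategy="fsdp_native_full_shard_offload", accelerator="gpu", devices=4)
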
@@ -26,7 +26,7 @@ from lightning_lite.strategies.fsdp import _FSDPBackwardSyncControl
 from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_12

 if _TORCH_GREATER_EQUAL_1_12:
-    from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel, MixedPrecision
+    from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, FullyShardedDataParallel, MixedPrecision


 @mock.patch("lightning_lite.strategies.fsdp._TORCH_GREATER_EQUAL_1_12", False)

@@ -36,13 +36,26 @@ def test_fsdp_support(*_):


 @RunIf(min_torch="1.12")
-def test_fsdp_custom_mixed_precision(*_):
+def test_fsdp_custom_mixed_precision():
     """Test that passing a custom mixed precision config works."""
     config = MixedPrecision()
     strategy = FSDPStrategy(mixed_precision=config)
     assert strategy.mixed_precision_config == config


+@RunIf(min_torch="1.12")
+def test_fsdp_cpu_offload():
+    """Test the different ways cpu offloading can be enabled."""
+    # bool
+    strategy = FSDPStrategy(cpu_offload=True)
+    assert strategy.cpu_offload == CPUOffload(offload_params=True)
+
+    # dataclass
+    config = CPUOffload()
+    strategy = FSDPStrategy(cpu_offload=config)
+    assert strategy.cpu_offload == config
+
+
 @RunIf(min_torch="1.12")
 def test_fsdp_setup_optimizer_validation():
     """Test that `setup_optimizer()` validates the param groups and reference to FSDP parameters."""

@@ -17,7 +17,7 @@ from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12
 from tests_pytorch.helpers.runif import RunIf

 if _TORCH_GREATER_EQUAL_1_12:
-    from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel, MixedPrecision
+    from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, FullyShardedDataParallel, MixedPrecision
     from torch.distributed.fsdp.wrap import wrap

@@ -306,3 +306,16 @@ def test_fully_sharded_native_activation_checkpointing():
     ) as ckpt_mock:
         strategy._setup_model(model)
     ckpt_mock.assert_called_with(fsdp_mock(), checkpoint_wrapper_fn=ANY, check_fn=ANY)
+
+
+@RunIf(min_torch="1.12")
+def test_fully_sharded_native_strategy_cpu_offload():
+    """Test the different ways cpu offloading can be enabled."""
+    # bool
+    strategy = DDPFullyShardedNativeStrategy(cpu_offload=True)
+    assert strategy.cpu_offload == CPUOffload(offload_params=True)
+
+    # dataclass
+    config = CPUOffload()
+    strategy = DDPFullyShardedNativeStrategy(cpu_offload=config)
+    assert strategy.cpu_offload == config