From 2debd1c6b6b35d3142eb5614369e0b91288b6b7d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 7 Dec 2022 03:55:47 +0100 Subject: [PATCH] Simplify enabling CPU offload in FSDP (#15832) Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com> --- .../advanced/model_parallel.rst | 3 +-- src/lightning_lite/strategies/fsdp.py | 20 +++++++++------- src/pytorch_lightning/CHANGELOG.md | 4 ++++ .../strategies/fully_sharded_native.py | 24 +++++++++---------- tests/tests_lite/strategies/test_fsdp.py | 17 +++++++++++-- .../test_ddp_fully_sharded_native.py | 15 +++++++++++- 6 files changed, 58 insertions(+), 25 deletions(-) diff --git a/docs/source-pytorch/advanced/model_parallel.rst b/docs/source-pytorch/advanced/model_parallel.rst index d2c86db5ba..db5605619a 100644 --- a/docs/source-pytorch/advanced/model_parallel.rst +++ b/docs/source-pytorch/advanced/model_parallel.rst @@ -424,10 +424,9 @@ You can customize the strategy configuration by adjusting the arguments of :clas from pytorch_lightning import Trainer from pytorch_lightning.strategies import DDPFullyShardedNativeStrategy - from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload - native_fsdp = DDPFullyShardedNativeStrategy(cpu_offload=CPUOffload(offload_params=True)) + native_fsdp = DDPFullyShardedNativeStrategy(cpu_offload=True) trainer = pl.Trainer(strategy=native_fsdp, accelerator="gpu", devices=4) diff --git a/src/lightning_lite/strategies/fsdp.py b/src/lightning_lite/strategies/fsdp.py index 9e19a5b77c..d333a72c3d 100644 --- a/src/lightning_lite/strategies/fsdp.py +++ b/src/lightning_lite/strategies/fsdp.py @@ -69,11 +69,10 @@ class FSDPStrategy(ParallelStrategy, _Sharded): `this tutorial `__ for more information. Arguments: - cpu_offload: CPU offloading config. Currently, only parameter and gradient CPU offload is supported. It - can be enabled via passing in ``cpu_offload=CPUOffload(offload_params=True)``. Note that this currently + cpu_offload: Enable offloading parameters and gradients to CPU to save GPU memory at the cost of speed. + You can also pass a config: ``cpu_offload=CPUOffload(offload_params=True)``. Note that this currently implicitly enables gradient offloading to CPU in order for parameters and gradients to be on same device - to work with the optimizer. This API is subject to change. Default is ``None`` in which case there - will be no offloading. + to work with the optimizer. This API is subject to change. Default: no offoading backward_prefetch: This is an experimental feature that is subject to change in the near future. It allows users to enable two different backward prefetching algorithms to help backward communication and computation overlapping. The pros and cons of each algorithm is explained in the class ``BackwardPrefetch``. @@ -96,7 +95,7 @@ class FSDPStrategy(ParallelStrategy, _Sharded): precision: Optional[Precision] = None, process_group_backend: Optional[str] = None, timeout: Optional[timedelta] = default_pg_timeout, - cpu_offload: Optional["CPUOffload"] = None, + cpu_offload: Union[bool, "CPUOffload", None] = None, backward_prefetch: Optional["BackwardPrefetch"] = None, mixed_precision: Optional["MixedPrecision"] = None, activation_checkpointing: Optional[Union[Type[Module], List[Type[Module]]]] = None, @@ -125,7 +124,7 @@ class FSDPStrategy(ParallelStrategy, _Sharded): [activation_checkpointing] if not isinstance(activation_checkpointing, list) else activation_checkpointing ) - self.cpu_offload = cpu_offload + self.cpu_offload = _init_cpu_offload(cpu_offload) self.backward_prefetch = backward_prefetch self.mixed_precision = mixed_precision @@ -276,7 +275,6 @@ class FSDPStrategy(ParallelStrategy, _Sharded): def register_strategies(cls, strategy_registry: Dict) -> None: if not _TORCH_GREATER_EQUAL_1_12 or not torch.distributed.is_available(): return - from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload strategy_registry.register( "fsdp", @@ -287,7 +285,7 @@ class FSDPStrategy(ParallelStrategy, _Sharded): "fsdp_full_shard_offload", cls, description="Native FSDP with Full Sharding and CPU Offloading", - cpu_offload=CPUOffload(offload_params=True), + cpu_offload=True, ) def _setup_distributed(self) -> None: @@ -341,6 +339,12 @@ class _FSDPBackwardSyncControl(_BackwardSyncControl): yield +def _init_cpu_offload(cpu_offload: Optional[Union[bool, "CPUOffload"]]) -> "CPUOffload": + from torch.distributed.fsdp import CPUOffload + + return cpu_offload if isinstance(cpu_offload, CPUOffload) else CPUOffload(offload_params=bool(cpu_offload)) + + def _optimizer_has_flat_params(optimizer: Optimizer) -> bool: from torch.distributed.fsdp import FlatParameter diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index 3537cad330..b8eb87ffd3 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -33,6 +33,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added support for activation checkpointing for the `DDPFullyShardedNativeStrategy` strategy ([#15826](https://github.com/Lightning-AI/lightning/pull/15826)) + +- Added the option to set `DDPFullyShardedNativeStrategy(cpu_offload=True|False)` via bool instead of needing to pass a configufation object ([#15832](https://github.com/Lightning-AI/lightning/pull/15832)) + + ### Changed - Drop PyTorch 1.9 support ([#15347](https://github.com/Lightning-AI/lightning/pull/15347)) diff --git a/src/pytorch_lightning/strategies/fully_sharded_native.py b/src/pytorch_lightning/strategies/fully_sharded_native.py index 38ed803235..f96b90f7b8 100644 --- a/src/pytorch_lightning/strategies/fully_sharded_native.py +++ b/src/pytorch_lightning/strategies/fully_sharded_native.py @@ -21,7 +21,11 @@ from torch.nn import Module import pytorch_lightning as pl from lightning_lite.plugins import CheckpointIO, ClusterEnvironment -from lightning_lite.strategies.fsdp import _optimizer_has_flat_params, _setup_activation_checkpointing +from lightning_lite.strategies.fsdp import ( + _init_cpu_offload, + _optimizer_has_flat_params, + _setup_activation_checkpointing, +) from lightning_lite.utilities.distributed import ( _get_default_process_group_backend_for_device, _init_dist_connection, @@ -84,14 +88,10 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy): `this tutorial `__ for more information. Arguments: - cpu_offload: - CPU offloading config. Currently, only parameter and gradient CPU - offload is supported. It can be enabled via passing in - ``cpu_offload=CPUOffload(offload_params=True)``. Note that this - currently implicitly enables gradient offloading to CPU in order for - params and grads to be on same device to work with optimizer. This - API is subject to change. Default is ``None`` in which case there - will be no offloading. + cpu_offload: Enable offloading parameters and gradients to CPU to save GPU memory at the cost of speed. + You can also pass a config: ``cpu_offload=CPUOffload(offload_params=True)``. Note that this currently + implicitly enables gradient offloading to CPU in order for parameters and gradients to be on same device + to work with the optimizer. This API is subject to change. Default: no offoading backward_prefetch: This is an experimental feature that is subject to change in the the near future. It allows users to enable two different backward_prefetch @@ -120,7 +120,7 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy): checkpoint_io: Optional[CheckpointIO] = None, precision_plugin: Optional[PrecisionPlugin] = None, process_group_backend: Optional[str] = None, - cpu_offload: Optional[CPUOffload] = None, + cpu_offload: Union[bool, "CPUOffload", None] = None, backward_prefetch: Optional[BackwardPrefetch] = None, mixed_precision: Optional[MixedPrecision] = None, activation_checkpointing: Optional[Union[Type[Module], List[Type[Module]]]] = None, @@ -141,7 +141,7 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy): self._process_group = None self.num_nodes = 1 self._process_group_backend = process_group_backend - self.cpu_offload = cpu_offload + self.cpu_offload = _init_cpu_offload(cpu_offload) self.backward_prefetch = backward_prefetch self.mixed_precision = mixed_precision self._rank_0_will_call_children_scripts: bool = False @@ -403,6 +403,6 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy): "fsdp_native_full_shard_offload", cls, description="Native FSDP with Full Sharding and CPU Offloading", - cpu_offload=CPUOffload(offload_params=True), + cpu_offload=True, ) cls._registered_strategies.append("fsdp_native_full_shard_offload") diff --git a/tests/tests_lite/strategies/test_fsdp.py b/tests/tests_lite/strategies/test_fsdp.py index 62880ea6e5..9b066b9e5a 100644 --- a/tests/tests_lite/strategies/test_fsdp.py +++ b/tests/tests_lite/strategies/test_fsdp.py @@ -26,7 +26,7 @@ from lightning_lite.strategies.fsdp import _FSDPBackwardSyncControl from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_12 if _TORCH_GREATER_EQUAL_1_12: - from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel, MixedPrecision + from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, FullyShardedDataParallel, MixedPrecision @mock.patch("lightning_lite.strategies.fsdp._TORCH_GREATER_EQUAL_1_12", False) @@ -36,13 +36,26 @@ def test_fsdp_support(*_): @RunIf(min_torch="1.12") -def test_fsdp_custom_mixed_precision(*_): +def test_fsdp_custom_mixed_precision(): """Test that passing a custom mixed precision config works.""" config = MixedPrecision() strategy = FSDPStrategy(mixed_precision=config) assert strategy.mixed_precision_config == config +@RunIf(min_torch="1.12") +def test_fsdp_cpu_offload(): + """Test the different ways cpu offloading can be enabled.""" + # bool + strategy = FSDPStrategy(cpu_offload=True) + assert strategy.cpu_offload == CPUOffload(offload_params=True) + + # dataclass + config = CPUOffload() + strategy = FSDPStrategy(cpu_offload=config) + assert strategy.cpu_offload == config + + @RunIf(min_torch="1.12") def test_fsdp_setup_optimizer_validation(): """Test that `setup_optimizer()` validates the param groups and reference to FSDP parameters.""" diff --git a/tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py b/tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py index a9b47aad1d..d9f8f86dd0 100644 --- a/tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py +++ b/tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py @@ -17,7 +17,7 @@ from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12 from tests_pytorch.helpers.runif import RunIf if _TORCH_GREATER_EQUAL_1_12: - from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel, MixedPrecision + from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, FullyShardedDataParallel, MixedPrecision from torch.distributed.fsdp.wrap import wrap @@ -306,3 +306,16 @@ def test_fully_sharded_native_activation_checkpointing(): ) as ckpt_mock: strategy._setup_model(model) ckpt_mock.assert_called_with(fsdp_mock(), checkpoint_wrapper_fn=ANY, check_fn=ANY) + + +@RunIf(min_torch="1.12") +def test_fully_sharded_native_strategy_cpu_offload(): + """Test the different ways cpu offloading can be enabled.""" + # bool + strategy = DDPFullyShardedNativeStrategy(cpu_offload=True) + assert strategy.cpu_offload == CPUOffload(offload_params=True) + + # dataclass + config = CPUOffload() + strategy = DDPFullyShardedNativeStrategy(cpu_offload=config) + assert strategy.cpu_offload == config