diff --git a/docs/source-fabric/api/fabric_methods.rst b/docs/source-fabric/api/fabric_methods.rst
index e1731e075e..75dfede327 100644
--- a/docs/source-fabric/api/fabric_methods.rst
+++ b/docs/source-fabric/api/fabric_methods.rst
@@ -139,6 +139,23 @@ Make your code reproducible by calling this method at the beginning of your run.
 This covers PyTorch, NumPy, and Python random number generators. In addition, Fabric takes care of properly initializing
 the seed of data loader worker processes (can be turned off by passing ``workers=False``).
 
+init_module
+===========
+
+Instantiating a ``nn.Module`` in PyTorch creates all parameters on CPU in float32 precision by default.
+To speed up initialization, you can force PyTorch to create the model directly on the target device and with the desired precision without changing your model code.
+
+.. code-block:: python
+
+    fabric = Fabric(accelerator="cuda", precision="16-true")
+
+    with fabric.init_module():
+        # models created here will be on GPU and in float16
+        model = MyModel()
+
+This eliminates the waiting time to transfer the model parameters from the CPU to the device.
+For strategies that handle large sharded models (FSDP, DeepSpeed), the :meth:`~lightning.fabric.fabric.Fabric.init_module` method will allocate the model parameters on the meta device first before sharding.
+This makes it possible to work with models that are larger than the memory of a single device.
 
 autocast
 ========
diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md
index 471e50a0dd..62b0c65866 100644
--- a/src/lightning/fabric/CHANGELOG.md
+++ b/src/lightning/fabric/CHANGELOG.md
@@ -30,6 +30,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added a warning when calling methods on `_FabricModule` that bypass the strategy-specific wrappers ([#17424](https://github.com/Lightning-AI/lightning/pull/17424))
 
+- Added `Fabric.init_module()` context manager to instantiate large models efficiently ([#17462](https://github.com/Lightning-AI/lightning/pull/17462))
+
+- Added `lightning.fabric.strategies.Strategy.module_init_context()` context manager to control the model instantiation ([#17462](https://github.com/Lightning-AI/lightning/pull/17462))
+  * Creates the model parameters in the desired dtype (`torch.float32`, `torch.float64`, `torch.float16`, or `torch.bfloat16`) depending on the 'true' precision choice in `Fabric(precision='32-true'|'64-true'|'16-true'|'bf16-true')`
+  * Initializes empty weights on the meta device for FSDP models and handles the Zero stage 3 initialization for DeepSpeed before sharding
+
 - Run the DDP wrapper in a CUDA stream ([#17334](https://github.com/Lightning-AI/lightning/pull/17334))
 
 
diff --git a/src/lightning/fabric/fabric.py b/src/lightning/fabric/fabric.py
index f076f4e50c..c02fce590c 100644
--- a/src/lightning/fabric/fabric.py
+++ b/src/lightning/fabric/fabric.py
@@ -22,7 +22,7 @@ import torch
 import torch.nn as nn
 from lightning_utilities.core.apply_func import apply_to_collection
 from lightning_utilities.core.overrides import is_overridden
-from lightning_utilities.core.rank_zero import rank_zero_warn
+from lightning_utilities.core.rank_zero import rank_zero_deprecation, rank_zero_warn
 from torch import Tensor
 from torch.optim import Optimizer
 from torch.utils.data import BatchSampler, DataLoader, DistributedSampler, RandomSampler, SequentialSampler
@@ -574,19 +574,30 @@ class Fabric:
     def sharded_model(self) -> Generator:
         """Shard the parameters of the model instantly when instantiating the layers.
 
-        Use this context manager with strategies that support sharding the model parameters to save peak memory usage.
-
-        Example::
-
-            with self.sharded_model():
-                model = MyModel()
-
-        The context manager is strategy-agnostic and for the ones that don't do sharding, it is a no-op.
+        .. deprecated:: This context manager is deprecated in favor of :meth:`init_module`, use it instead.
         """
-        if isinstance(self._strategy, _Sharded):
-            with self._strategy.module_sharded_context():
-                yield
-        else:
+        rank_zero_deprecation("`Fabric.sharded_model()` is deprecated in favor of `Fabric.init_module()`.")
+        with _old_sharded_model_context(self._strategy):
+            yield
+
+    @contextmanager
+    def init_module(self) -> Generator:
+        """Instantiate the model and its parameters under this context manager to reduce peak memory usage.
+
+        The parameters get created on the device and with the right data type right away without wasting memory being
+        allocated unnecessarily.
+
+        Note:
+            The automatic device placement under this context manager is only supported with PyTorch 2.0 and newer.
+        """
+        if not _TORCH_GREATER_EQUAL_2_0 and self.device.type != "cpu":
+            rank_zero_warn(
+                "`Fabric.init_module()` can't place the model parameters on the device directly with PyTorch < 2.0."
+                " Parameters will remain on CPU until `Fabric.setup()` is called. Upgrade to PyTorch >= 2.0 to fully"
+                " utilize the features in `init_module()`.",
+                category=PossibleUserWarning,
+            )
+        with self._strategy.module_init_context():
             yield
 
     def save(self, path: Union[str, Path], state: Dict[str, Union[nn.Module, Optimizer, Any]]) -> None:
@@ -763,9 +774,9 @@ class Fabric:
     def _run_with_setup(self, run_function: Callable, *args: Any, **kwargs: Any) -> Any:
         self._strategy.setup_environment()
         # apply sharded context to prevent OOM
-        with self.sharded_model(), _replace_dunder_methods(DataLoader, "dataset"), _replace_dunder_methods(
-            BatchSampler
-        ):
+        with _old_sharded_model_context(self._strategy), _replace_dunder_methods(
+            DataLoader, "dataset"
+        ), _replace_dunder_methods(BatchSampler):
             return run_function(*args, **kwargs)
 
     def _move_model_to_device(self, model: nn.Module, optimizers: List[Optimizer]) -> nn.Module:
@@ -864,3 +875,12 @@ class Fabric:
 
 def _is_using_cli() -> bool:
     return bool(int(os.environ.get("LT_CLI_USED", "0")))
+
+
+@contextmanager
+def _old_sharded_model_context(strategy: Strategy) -> Generator:
+    if isinstance(strategy, _Sharded):
+        with strategy.module_sharded_context():
+            yield
+    else:
+        yield
diff --git a/src/lightning/fabric/plugins/precision/double.py b/src/lightning/fabric/plugins/precision/double.py
index 16906613f6..fd495ef5c0 100644
--- a/src/lightning/fabric/plugins/precision/double.py
+++ b/src/lightning/fabric/plugins/precision/double.py
@@ -31,6 +31,17 @@ class DoublePrecision(Precision):
     def convert_module(self, module: Module) -> Module:
         return module.double()
 
+    @contextmanager
+    def module_init_context(self) -> Generator[None, None, None]:
+        """A context manager to change the default dtype when initializing the parameters in a module.
+
+        See: :func:`torch.set_default_dtype`
+        """
+        default_dtype = torch.get_default_dtype()
+        torch.set_default_dtype(torch.float64)
+        yield
+        torch.set_default_dtype(default_dtype)
+
     @contextmanager
     def forward_context(self) -> Generator[None, None, None]:
         """A context manager to change the default tensor type.
diff --git a/src/lightning/fabric/plugins/precision/precision.py b/src/lightning/fabric/plugins/precision/precision.py
index 851c9ee000..dddbfc20c8 100644
--- a/src/lightning/fabric/plugins/precision/precision.py
+++ b/src/lightning/fabric/plugins/precision/precision.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import contextlib
+from contextlib import contextmanager
 from typing import Any, Dict, Generator, Literal, Optional, Union
 
 from torch import Tensor
@@ -42,7 +42,15 @@ class Precision:
         """
         return module
 
-    @contextlib.contextmanager
+    @contextmanager
+    def module_init_context(self) -> Generator[None, None, None]:
+        """Instantiate the module parameters in the precision type this plugin handles.
+
+        This is optional and depends on the precision limitations during optimization.
+        """
+        yield
+
+    @contextmanager
     def forward_context(self) -> Generator[None, None, None]:
         """A contextmanager for managing model forward/training_step/evaluation_step/predict_step."""
         yield
diff --git a/src/lightning/fabric/strategies/deepspeed.py b/src/lightning/fabric/strategies/deepspeed.py
index 467b339db5..d1133812c0 100644
--- a/src/lightning/fabric/strategies/deepspeed.py
+++ b/src/lightning/fabric/strategies/deepspeed.py
@@ -338,6 +338,11 @@ class DeepSpeedStrategy(DDPStrategy, _Sharded):
         """
         raise NotImplementedError(self._err_msg_joint_setup_required())
 
+    @contextmanager
+    def module_init_context(self) -> Generator[None, None, None]:
+        with super().module_init_context(), self.module_sharded_context():
+            yield
+
     @contextmanager
     def module_sharded_context(self) -> Generator[None, None, None]:
         # Current limitation in Fabric: The config needs to be fully determined at the time of calling the context
diff --git a/src/lightning/fabric/strategies/fsdp.py b/src/lightning/fabric/strategies/fsdp.py
index 116c0c747b..a3171b2f8e 100644
--- a/src/lightning/fabric/strategies/fsdp.py
+++ b/src/lightning/fabric/strategies/fsdp.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 import functools
 import os
-from contextlib import _GeneratorContextManager, contextmanager
+from contextlib import _GeneratorContextManager, contextmanager, nullcontext
 from datetime import timedelta
 from pathlib import Path
 from typing import Any, Dict, Generator, List, Optional, Tuple, Type, TYPE_CHECKING, Union
@@ -243,6 +243,12 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
     def module_to_device(self, module: Module) -> None:
         pass
 
+    @contextmanager
+    def module_init_context(self) -> Generator[None, None, None]:
+        device_context = torch.device("meta") if _TORCH_GREATER_EQUAL_2_0 else nullcontext()
+        with device_context, self.precision.module_init_context(), self.module_sharded_context():
+            yield
+
     @contextmanager
     def module_sharded_context(self) -> Generator:
         from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel
diff --git a/src/lightning/fabric/strategies/strategy.py b/src/lightning/fabric/strategies/strategy.py
index afffaf21d0..95caba8a81 100644
--- a/src/lightning/fabric/strategies/strategy.py
+++ b/src/lightning/fabric/strategies/strategy.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 import logging
 from abc import ABC, abstractmethod
-from contextlib import contextmanager
+from contextlib import contextmanager, nullcontext
 from typing import Any, Dict, Generator, List, Optional, Tuple, TypeVar, Union
 
 import torch
@@ -28,6 +28,7 @@ from lightning.fabric.plugins.io.torch_io import TorchCheckpointIO
 from lightning.fabric.plugins.precision import Precision
 from lightning.fabric.strategies.launchers.launcher import _Launcher
 from lightning.fabric.utilities.apply_func import move_data_to_device
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
 from lightning.fabric.utilities.types import _PATH, _Stateful, Optimizable, ReduceOp
 
 TBroadcast = TypeVar("TBroadcast")
@@ -111,6 +112,17 @@ class Strategy(ABC):
         """
         return dataloader
 
+    @contextmanager
+    def module_init_context(self) -> Generator:
+        """A context manager wrapping the model instantiation.
+
+        Here, the strategy can control how the parameters of the model get created (device, dtype) and/or apply other
+        patches to the model.
+        """
+        device_context = self.root_device if _TORCH_GREATER_EQUAL_2_0 else nullcontext()
+        with device_context, self.precision.module_init_context():
+            yield
+
     def setup_module_and_optimizers(
         self, module: Module, optimizers: List[Optimizer]
     ) -> Tuple[Module, List[Optimizer]]:
diff --git a/tests/tests_fabric/plugins/precision/test_double_integration.py b/tests/tests_fabric/plugins/precision/test_double_integration.py
index cfd0918af8..76519ca8e6 100644
--- a/tests/tests_fabric/plugins/precision/test_double_integration.py
+++ b/tests/tests_fabric/plugins/precision/test_double_integration.py
@@ -17,6 +17,7 @@ import torch
 import torch.nn as nn
 
 from tests_fabric.helpers.models import BoringFabric
+from tests_fabric.helpers.runif import RunIf
 
 
 class BoringDoubleModule(nn.Module):
@@ -50,6 +51,7 @@ class DoublePrecisionBoringFabric(BoringFabric):
         assert model.layer.weight.grad.dtype == torch.float64
 
 
+@RunIf(mps=False)  # MPS doesn't support float64
 def test_double_precision():
     fabric = DoublePrecisionBoringFabric(devices=1, precision="64-true")
     fabric.run()
diff --git a/tests/tests_fabric/strategies/test_ddp.py b/tests/tests_fabric/strategies/test_ddp.py
index 4ea9a151ea..7e0f0ee8e2 100644
--- a/tests/tests_fabric/strategies/test_ddp.py
+++ b/tests/tests_fabric/strategies/test_ddp.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
 from unittest import mock
 from unittest.mock import MagicMock, Mock
 
@@ -18,8 +19,11 @@ import pytest
 import torch
 from torch.nn.parallel import DistributedDataParallel
 
+from lightning.fabric.plugins import DoublePrecision, Precision
+from lightning.fabric.plugins.environments import LightningEnvironment
 from lightning.fabric.strategies import DDPStrategy
 from lightning.fabric.strategies.ddp import _DDPBackwardSyncControl
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
 from tests_fabric.helpers.runif import RunIf
 from tests_fabric.strategies.test_single_device import _MyFabricGradNorm, _MyFabricGradVal
 
@@ -122,3 +126,27 @@ def test_ddp_grad_clipping(clip_type, accelerator, precision):
     clipping_test_cls = _MyFabricGradNorm if clip_type == "norm" else _MyFabricGradVal
     fabric = clipping_test_cls(accelerator=accelerator, devices=2, precision=precision, strategy="ddp")
     fabric.run()
+
+
+@RunIf(min_cuda_gpus=2)
+@pytest.mark.parametrize(
+    "precision,expected_dtype",
+    [
+        (Precision(), torch.float32),
+        (DoublePrecision(), torch.float64),
+    ],
+)
+@mock.patch.dict(os.environ, {"LOCAL_RANK": "1"})
+def test_module_init_context(precision, expected_dtype):
+    """Test that the module under the init-context gets moved to the right device and dtype."""
+    parallel_devices = [torch.device("cuda", 0), torch.device("cuda", 1)]
+    expected_device = parallel_devices[1] if _TORCH_GREATER_EQUAL_2_0 else torch.device("cpu")
+
+    strategy = DDPStrategy(
+        parallel_devices=parallel_devices, precision=precision, cluster_environment=LightningEnvironment()
+    )
+    assert strategy.local_rank == 1
+    with strategy.module_init_context():
+        module = torch.nn.Linear(2, 2)
+    assert module.weight.device == module.bias.device == expected_device
+    assert module.weight.dtype == module.bias.dtype == expected_dtype
diff --git a/tests/tests_fabric/strategies/test_deepspeed_integration.py b/tests/tests_fabric/strategies/test_deepspeed_integration.py
index 3c4d1d24a4..d509a12bef 100644
--- a/tests/tests_fabric/strategies/test_deepspeed_integration.py
+++ b/tests/tests_fabric/strategies/test_deepspeed_integration.py
@@ -366,7 +366,7 @@ def test_deepspeed_save_load_checkpoint_zero_3(stage, tmp_path):
 
     checkpoint_path = fabric.broadcast(tmp_path / "deepspeed-checkpoint")
 
-    with fabric.sharded_model():
+    with fabric.init_module():
         model = BoringModel()
         optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)
 
@@ -390,7 +390,7 @@ def test_deepspeed_save_load_checkpoint_zero_3(stage, tmp_path):
     # re-init all objects and resume
     fabric = Fabric(accelerator="cuda", devices=2, strategy=DeepSpeedStrategy(stage=stage), precision="bf16")
     fabric.launch()
-    with fabric.sharded_model():
+    with fabric.init_module():
         model = BoringModel()
         optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)
 
diff --git a/tests/tests_fabric/strategies/test_fsdp_integration.py b/tests/tests_fabric/strategies/test_fsdp_integration.py
index 14c1016527..024475940f 100644
--- a/tests/tests_fabric/strategies/test_fsdp_integration.py
+++ b/tests/tests_fabric/strategies/test_fsdp_integration.py
@@ -191,3 +191,37 @@ def test_compile(compile_after_setup):
 
     for _ in range(3):
         model(torch.rand(2, 32, device=fabric.device)).sum().backward()
+
+
+@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
+@pytest.mark.parametrize(
+    "precision,expected_dtype",
+    [
+        ("32-true", torch.float32),
+        ("64-true", torch.float64),
+    ],
+)
+def test_module_init_context(precision, expected_dtype):
+    """Test that the module under the init-context gets moved to the right device and dtype."""
+    fabric = Fabric(
+        accelerator="cuda",
+        devices=2,
+        strategy=FSDPStrategy(auto_wrap_policy=always_wrap_policy),
+        precision=precision,
+    )
+    fabric.launch()
+
+    with fabric.init_module():
+        model = torch.nn.Linear(100, 100, bias=False)
+
+    # The model is on the meta device until `.setup()`
+    expected_device = torch.device("meta") if _TORCH_GREATER_EQUAL_2_0 else torch.device("cpu")
+    assert model.weight.device == expected_device
+    assert model.weight.dtype == expected_dtype
+
+    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
+    model, optimizer = fabric.setup(model, optimizer)
+
+    # Parameters get sharded in `.setup()` and moved to the target device
+    assert model.weight.device == torch.device("cuda", fabric.local_rank)
+    assert model.weight.dtype == expected_dtype
diff --git a/tests/tests_fabric/strategies/test_single_device.py b/tests/tests_fabric/strategies/test_single_device.py
index 005de6bbd6..9271fb590b 100644
--- a/tests/tests_fabric/strategies/test_single_device.py
+++ b/tests/tests_fabric/strategies/test_single_device.py
@@ -16,7 +16,9 @@ from unittest.mock import Mock
 import pytest
 import torch
 
+from lightning.fabric.plugins import DoublePrecision, Precision
 from lightning.fabric.strategies import SingleDeviceStrategy
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
 from lightning.fabric.wrappers import _FabricModule, _FabricOptimizer
 from tests_fabric.helpers.models import BoringFabric
 from tests_fabric.helpers.runif import RunIf
@@ -147,3 +149,30 @@ def test_single_device_grad_clipping(clip_type, precision):
     clipping_test_cls = _MyFabricGradNorm if clip_type == "norm" else _MyFabricGradVal
     fabric = clipping_test_cls(accelerator="auto", devices=1, precision=precision)
     fabric.run()
+
+
+@pytest.mark.parametrize(
+    "device",
+    [
+        "cpu",
+        pytest.param("cuda:0", marks=RunIf(min_cuda_gpus=1)),
+        pytest.param("mps:0", marks=RunIf(mps=True)),
+    ],
+)
+@pytest.mark.parametrize(
+    "precision,dtype",
+    [
+        (Precision(), torch.float32),
+        pytest.param(DoublePrecision(), torch.float64, marks=RunIf(mps=False)),
+    ],
+)
+def test_module_init_context(device, precision, dtype):
+    """Test that the module under the init-context gets moved to the right device and dtype."""
+    device = torch.device(device)
+    strategy = SingleDeviceStrategy(device=device, precision=precision)
+    with strategy.module_init_context():
+        module = torch.nn.Linear(2, 2)
+
+    expected_device = device if _TORCH_GREATER_EQUAL_2_0 else torch.device("cpu")
+    assert module.weight.device == module.bias.device == expected_device
+    assert module.weight.dtype == module.bias.dtype == dtype
diff --git a/tests/tests_fabric/test_connector.py b/tests/tests_fabric/test_connector.py
index 451cb94097..b6b9da4947 100644
--- a/tests/tests_fabric/test_connector.py
+++ b/tests/tests_fabric/test_connector.py
@@ -277,7 +277,7 @@ def test_interactive_compatible_strategy_ddp_fork(monkeypatch):
         pytest.param("deepspeed", DeepSpeedStrategy, marks=RunIf(deepspeed=True)),
     ),
 )
-@pytest.mark.parametrize("accelerator", ["mps", "auto", "gpu", None, MPSAccelerator()])
+@pytest.mark.parametrize("accelerator", ["mps", "auto", "gpu", MPSAccelerator()])
 def test_invalid_ddp_strategy_with_mps(accelerator, strategy, strategy_class):
     with pytest.raises(ValueError, match="strategies from the DDP family are not supported"):
         _Connector(accelerator=accelerator, strategy=strategy)
diff --git a/tests/tests_fabric/test_fabric.py b/tests/tests_fabric/test_fabric.py
index 1008ba3ab9..40cf0df486 100644
--- a/tests/tests_fabric/test_fabric.py
+++ b/tests/tests_fabric/test_fabric.py
@@ -692,6 +692,7 @@ def test_launch_and_cli_not_allowed():
         fabric.launch()
 
 
+@RunIf(mps=False)
 @pytest.mark.parametrize("strategy", ("xla", "ddp_spawn"))
 def test_launch_and_strategies_unsupported_combinations(strategy, xla_available):
     fabric = Fabric(strategy=strategy)
@@ -714,16 +715,36 @@ def test_module_sharding_context():
     otherwise."""
     fabric = Fabric()
     fabric._strategy = MagicMock(spec=DDPStrategy, module_sharded_context=Mock())
-    with fabric.sharded_model():
+    with pytest.warns(DeprecationWarning, match="sharded_model"), fabric.sharded_model():
         pass
     fabric._strategy.module_sharded_context.assert_not_called()
 
     fabric._strategy = MagicMock(spec=_Sharded)
-    with fabric.sharded_model():
+    with pytest.warns(DeprecationWarning, match="sharded_model"), fabric.sharded_model():
        pass
     fabric._strategy.module_sharded_context.assert_called_once()
 
 
+def test_init_module_context(monkeypatch):
+    """Test that the strategy returns the context manager for initializing the module."""
+    import lightning.fabric
+
+    fabric = Fabric(accelerator="cpu")
+    strategy = MagicMock(spec=Strategy, module_init_context=MagicMock(), root_device=torch.device("cuda", 0))
+    fabric._strategy = strategy
+    with fabric.init_module():
+        pass
+    strategy.module_init_context.assert_called_once()
+    strategy.reset_mock()
+
+    # Pretend we are using PyTorch < 2.0
+    monkeypatch.setattr(lightning.fabric.fabric, "_TORCH_GREATER_EQUAL_2_0", False)
+    with pytest.warns(PossibleUserWarning, match="can't place the model parameters on the device"):  # noqa: SIM117
+        with fabric.init_module():
+            pass
+    strategy.module_init_context.assert_called_once()
+
+
 def test_callbacks_input():
     """Test the various ways in which callbacks can be registered with Fabric."""
     callback0 = Mock()
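
As a usage sketch for reviewers (assuming a machine with at least two CUDA GPUs and PyTorch >= 2.0), the snippet below mirrors the flow exercised by the new `test_module_init_context` FSDP integration test; the small `Linear` layer stands in for a genuinely large model, and per the docs section above the parameters live on the meta device until `setup()` shards them.

```python
import torch
from torch.distributed.fsdp.wrap import always_wrap_policy

from lightning.fabric import Fabric
from lightning.fabric.strategies import FSDPStrategy

fabric = Fabric(
    accelerator="cuda",
    devices=2,
    strategy=FSDPStrategy(auto_wrap_policy=always_wrap_policy),
    precision="32-true",  # with "64-true" (or another "true" precision) the parameters are created in that dtype
)
fabric.launch()

# With PyTorch >= 2.0 and FSDP, parameters are created on the meta device here,
# which makes it possible to work with models larger than the memory of a single device.
with fabric.init_module():
    model = torch.nn.Linear(100, 100, bias=False)  # stand-in for a large model

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Materialization, sharding, and the move to the target CUDA device happen in `setup()`.
model, optimizer = fabric.setup(model, optimizer)
```

On older PyTorch versions, `init_module()` falls back to creating the parameters on CPU and emits a `PossibleUserWarning`, as covered by `test_init_module_context` above.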