From 67a14795cff77db7efec796d0ed5e9b7d0c3377d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Sat, 3 Jun 2023 04:07:02 +0200
Subject: [PATCH] Address feedback for `Fabric.init_module()` (4/4) (#17607)

---
 src/lightning/fabric/CHANGELOG.md             | 11 +++---
 src/lightning/fabric/fabric.py                | 14 +++----
 src/lightning/fabric/strategies/deepspeed.py  |  8 +++-
 src/lightning/fabric/strategies/fsdp.py       |  4 +-
 src/lightning/fabric/strategies/strategy.py   | 14 +++++--
 tests/tests_fabric/strategies/test_ddp.py     |  4 +-
 .../strategies/test_fsdp_integration.py       |  2 +-
 .../strategies/test_single_device.py          | 39 +++++++++++++++++--
 tests/tests_fabric/test_fabric.py             | 23 ++++++-----
 9 files changed, 84 insertions(+), 35 deletions(-)

diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md
index 954fc2381c..b7123cd02b 100644
--- a/src/lightning/fabric/CHANGELOG.md
+++ b/src/lightning/fabric/CHANGELOG.md
@@ -33,16 +33,15 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added a warning when calling methods on `_FabricModule` that bypass the strategy-specific wrappers ([#17424](https://github.com/Lightning-AI/lightning/pull/17424))
 
-- Added `Fabric.init()` context manager to instantiate tensors or models efficiently directly on device and dtype ([#17488](https://github.com/Lightning-AI/lightning/pull/17488))
+- Added `Fabric.init_tensor()` context manager to instantiate tensors efficiently directly on device and dtype ([#17488](https://github.com/Lightning-AI/lightning/pull/17488))
+
 
 - Added `Fabric.init_module()` context manager to instantiate large models efficiently directly on device, dtype, and with sharding support ([#17462](https://github.com/Lightning-AI/lightning/pull/17462))
-
-- Added `lightning.fabric.plugins.Precision.init_context()` context manager to control model and tensor instantiation ([#17462](https://github.com/Lightning-AI/lightning/pull/17462))
   * Creates the model parameters in the desired dtype (`torch.float32`, `torch.float64`, `torch.float16`, or `torch.bfloat16`) depending on the 'true' precision choice in `Fabric(precision='32-true'|'64-true'|'16-true'|'bf16-true')`
+  * Handles initialization for FSDP models before wrapping and the Zero stage 3 initialization for DeepSpeed before sharding
+
+- Added `lightning.fabric.plugins.Precision.init_context()` and `lightning.fabric.strategies.Strategy.module_init_context()` context managers to control model and tensor instantiation ([#17462](https://github.com/Lightning-AI/lightning/pull/17462))
 
-- Added `lightning.fabric.strategies.Strategy.init_context()` context manager to control the model and tensor instantiation ([#17462](https://github.com/Lightning-AI/lightning/pull/17462))
-  * Calls `lightning.fabric.plugins.Precision.init_context()`
-  * Initializes empty weights on the root device.
 
 - Run the DDP wrapper in a CUDA stream ([#17334](https://github.com/Lightning-AI/lightning/pull/17334))
 
diff --git a/src/lightning/fabric/fabric.py b/src/lightning/fabric/fabric.py
index aa69032c80..875f140984 100644
--- a/src/lightning/fabric/fabric.py
+++ b/src/lightning/fabric/fabric.py
@@ -593,20 +593,20 @@ class Fabric:
             yield
 
     @contextmanager
-    def init(self) -> Generator:
-        """Instantiate under this context manager to apply improvements based on your configuration.
+    def init_tensor(self) -> Generator:
+        """Tensors that you instantiate under this context manager will be created on the device right away and
+        have the right data type depending on the precision setting in Fabric.
 
-        The parameters get created on the device (if using PyTorch 2.0 or newer) and with the right data type right away
-        without wasting memory being allocated unnecessarily.
+        The automatic device placement under this context manager is only supported with PyTorch 2.0 and newer.
         """
         if not _TORCH_GREATER_EQUAL_2_0 and self.device.type != "cpu":
             rank_zero_warn(
-                "`Fabric.init()` can't place the model parameters on the device directly"
+                "`Fabric.init_tensor()` can't place tensors on the device directly"
                 " with PyTorch < 2.0. Parameters will remain on CPU until `Fabric.setup()` is called."
                 " Upgrade to PyTorch >= 2.0 to fully utilize this feature.",
                 category=PossibleUserWarning,
             )
-        with self._strategy.init_context():
+        with self._strategy.tensor_init_context():
             yield
 
     @contextmanager
@@ -624,7 +624,7 @@ class Fabric:
                 " Upgrade to PyTorch >= 2.0 to fully utilize this feature.",
                 category=PossibleUserWarning,
             )
-        with self._strategy.init_context(), _old_sharded_model_context(self.strategy):
+        with self._strategy.module_init_context():
             yield
 
     def save(self, path: Union[str, Path], state: Dict[str, Union[nn.Module, Optimizer, Any]]) -> None:
diff --git a/src/lightning/fabric/strategies/deepspeed.py b/src/lightning/fabric/strategies/deepspeed.py
index 1f4ce7edf4..ca84c1e9cc 100644
--- a/src/lightning/fabric/strategies/deepspeed.py
+++ b/src/lightning/fabric/strategies/deepspeed.py
@@ -16,7 +16,7 @@ import json
 import logging
 import os
 import platform
-from contextlib import contextmanager
+from contextlib import contextmanager, nullcontext
 from itertools import chain
 from pathlib import Path
 from typing import Any, Dict, Generator, List, Mapping, Optional, Tuple, TYPE_CHECKING, Union
@@ -339,6 +339,12 @@ class DeepSpeedStrategy(DDPStrategy, _Sharded):
         """
         raise NotImplementedError(self._err_msg_joint_setup_required())
 
+    @contextmanager
+    def module_init_context(self) -> Generator[None, None, None]:
+        precision_context = self.precision.init_context() if not self.zero_stage_3 else nullcontext()
+        with precision_context, self.module_sharded_context():
+            yield
+
     @contextmanager
     def module_sharded_context(self) -> Generator[None, None, None]:
         # Current limitation in Fabric: The config needs to be fully determined at the time of calling the context
diff --git a/src/lightning/fabric/strategies/fsdp.py b/src/lightning/fabric/strategies/fsdp.py
index 4139dc819b..f609046155 100644
--- a/src/lightning/fabric/strategies/fsdp.py
+++ b/src/lightning/fabric/strategies/fsdp.py
@@ -254,10 +254,10 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
         pass
 
     @contextmanager
-    def init_context(self) -> Generator[None, None, None]:
+    def module_init_context(self) -> Generator[None, None, None]:
         # TODO: Use the meta device and reset parameters after https://github.com/pytorch/pytorch/issues/90465
         # is resolved. For now, the module will get moved to the device in `setup_module`.
-        with self.precision.init_context():
+        with self.precision.init_context(), self.module_sharded_context():
             yield
 
     @contextmanager
diff --git a/src/lightning/fabric/strategies/strategy.py b/src/lightning/fabric/strategies/strategy.py
index 45e4137832..d8c910ad0f 100644
--- a/src/lightning/fabric/strategies/strategy.py
+++ b/src/lightning/fabric/strategies/strategy.py
@@ -114,14 +114,20 @@ class Strategy(ABC):
         return dataloader
 
     @contextmanager
-    def init_context(self) -> Generator:
-        """A context manager for improved tensor and module instantiation.
+    def tensor_init_context(self) -> Generator:
+        """Controls how tensors get created (device, dtype)."""
+        device_context = self.root_device if _TORCH_GREATER_EQUAL_2_0 else nullcontext()
+        with device_context, self.precision.init_context():
+            yield
+
+    @contextmanager
+    def module_init_context(self) -> Generator:
+        """A context manager wrapping the model instantiation.
 
         Here, the strategy can control how the parameters of the model get created (device, dtype) and or apply other
         patches to the model.
         """
-        device_context = self.root_device if _TORCH_GREATER_EQUAL_2_0 else nullcontext()
-        with device_context, self.precision.init_context():
+        with self.tensor_init_context():
             yield
 
     def setup_module_and_optimizers(
diff --git a/tests/tests_fabric/strategies/test_ddp.py b/tests/tests_fabric/strategies/test_ddp.py
index b7bc6f982f..42fd89ee9d 100644
--- a/tests/tests_fabric/strategies/test_ddp.py
+++ b/tests/tests_fabric/strategies/test_ddp.py
@@ -140,7 +140,7 @@ def test_ddp_grad_clipping(clip_type, accelerator, precision):
     ],
 )
 @mock.patch.dict(os.environ, {"LOCAL_RANK": "1"})
-def test_init_context(precision, expected_dtype):
+def test_module_init_context(precision, expected_dtype):
     """Test that the module under the init-context gets moved to the right device and dtype."""
     parallel_devices = [torch.device("cuda", 0), torch.device("cuda", 1)]
     expected_device = parallel_devices[1] if _TORCH_GREATER_EQUAL_2_0 else torch.device("cpu")
@@ -149,7 +149,7 @@ def test_init_context(precision, expected_dtype):
         parallel_devices=parallel_devices, precision=precision, cluster_environment=LightningEnvironment()
     )
     assert strategy.local_rank == 1
-    with strategy.init_context():
+    with strategy.module_init_context():
         module = torch.nn.Linear(2, 2)
     assert module.weight.device == module.bias.device == expected_device
     assert module.weight.dtype == module.bias.dtype == expected_dtype
diff --git a/tests/tests_fabric/strategies/test_fsdp_integration.py b/tests/tests_fabric/strategies/test_fsdp_integration.py
index d208512c71..8c3800bfd7 100644
--- a/tests/tests_fabric/strategies/test_fsdp_integration.py
+++ b/tests/tests_fabric/strategies/test_fsdp_integration.py
@@ -284,7 +284,7 @@ def test_compile(compile_after_setup):
         ("64-true", torch.float64),
     ],
 )
-def test_init_context(precision, expected_dtype):
+def test_module_init_context(precision, expected_dtype):
     """Test that the module under the init-context gets moved to the right device and dtype."""
     fabric = Fabric(
         accelerator="cuda",
diff --git a/tests/tests_fabric/strategies/test_single_device.py b/tests/tests_fabric/strategies/test_single_device.py
index 1a26acade7..0d8ae3efcf 100644
--- a/tests/tests_fabric/strategies/test_single_device.py
+++ b/tests/tests_fabric/strategies/test_single_device.py
@@ -168,13 +168,46 @@ def test_single_device_grad_clipping(clip_type, precision):
         pytest.param(DoublePrecision(), torch.float64, marks=RunIf(mps=False)),
     ],
 )
-def test_init_context(device, precision, dtype):
-    """Test that the module under the init-context gets moved to the right device and dtype."""
+def test_module_init_context(device, precision, dtype):
+    """Test that the module under the init-module-context gets moved to the right device and dtype."""
     device = torch.device(device)
     strategy = SingleDeviceStrategy(device=device, precision=precision)
-    with strategy.init_context():
+    with strategy.module_init_context():
         module = torch.nn.Linear(2, 2)
     expected_device = device if _TORCH_GREATER_EQUAL_2_0 else torch.device("cpu")
     assert module.weight.device == module.bias.device == expected_device
     assert module.weight.dtype == module.bias.dtype == dtype
+
+
+@pytest.mark.parametrize(
+    "device",
+    [
+        "cpu",
+        pytest.param("cuda:0", marks=RunIf(min_cuda_gpus=1)),
+        pytest.param("mps:0", marks=RunIf(mps=True)),
+    ],
+)
+@pytest.mark.parametrize(
+    ("precision", "dtype"),
+    [
+        (Precision(), torch.float32),
+        (HalfPrecision("16-true"), torch.float16),
+        pytest.param(HalfPrecision("bf16-true"), torch.bfloat16, marks=RunIf(mps=False)),
+        pytest.param(DoublePrecision(), torch.float64, marks=RunIf(mps=False)),
+    ],
+)
+def test_tensor_init_context(device, precision, dtype):
+    """Test that tensors under the init-tensor-context get moved to the right device and dtype."""
+    device = torch.device(device)
+    strategy = SingleDeviceStrategy(device=device, precision=precision)
+    with strategy.tensor_init_context():
+        tensor0 = torch.tensor(42.0)
+        tensor1 = torch.tensor(42)
+        tensor2 = torch.tensor(42.0, dtype=torch.half)
+
+    expected_device = device if _TORCH_GREATER_EQUAL_2_0 else torch.device("cpu")
+    assert tensor0.device == tensor1.device == tensor2.device == expected_device
+    assert tensor0.dtype == dtype
+    assert tensor1.dtype == torch.long  # `.init_tensor()` only affects floating point dtypes
+    assert tensor2.dtype == torch.half  # this tensor was created with an explicit dtype assignment
 
diff --git a/tests/tests_fabric/test_fabric.py b/tests/tests_fabric/test_fabric.py
index e60bb31231..9cba2a6b14 100644
--- a/tests/tests_fabric/test_fabric.py
+++ b/tests/tests_fabric/test_fabric.py
@@ -735,35 +735,40 @@ def test_init_module_context(monkeypatch):
 
     fabric = Fabric(accelerator="cpu")
     strategy = SingleDeviceStrategy(device=torch.device("cuda"))
-    strategy.init_context = Mock(wraps=strategy.init_context)
+    strategy.module_init_context = Mock(wraps=strategy.module_init_context)
     fabric._strategy = strategy
     with fabric.init_module():
         pass
-    strategy.init_context.assert_called_once()
-    strategy.init_context.reset_mock()
+    strategy.module_init_context.assert_called_once()
+    strategy.module_init_context.reset_mock()
 
     # Pretend we are using PyTorch < 2.0
     monkeypatch.setattr(lightning.fabric.fabric, "_TORCH_GREATER_EQUAL_2_0", False)
     with pytest.warns(PossibleUserWarning, match="can't place the model parameters on the device"):  # noqa: SIM117
         with fabric.init_module():
             pass
-    strategy.init_context.assert_called_once()
+    strategy.module_init_context.assert_called_once()
 
 
-def test_init_context(monkeypatch):
-    """Test that `.init()` warns if using PyTorch < 2.0."""
-    # TODO(awaelchli): Extend the test once `Fabric.init()` finalized
+def test_init_tensor_context(monkeypatch):
+    """Test that `.init_tensor()` warns if using PyTorch < 2.0."""
     import lightning.fabric
 
     fabric = Fabric(accelerator="cpu")
    strategy = SingleDeviceStrategy(device=torch.device("cuda"))
+    strategy.tensor_init_context = Mock(wraps=strategy.tensor_init_context)
     fabric._strategy = strategy
+    with fabric.init_tensor():
+        pass
+    strategy.tensor_init_context.assert_called_once()
+    strategy.tensor_init_context.reset_mock()
 
     # Pretend we are using PyTorch < 2.0
     monkeypatch.setattr(lightning.fabric.fabric, "_TORCH_GREATER_EQUAL_2_0", False)
-    with pytest.warns(PossibleUserWarning, match="can't place the model parameters on the device"):  # noqa: SIM117
-        with fabric.init():
+    with pytest.warns(PossibleUserWarning, match="can't place tensors on the device directly"):  # noqa: SIM117
+        with fabric.init_tensor():
            pass
+    strategy.tensor_init_context.assert_called_once()
 
 
 def test_callbacks_input():
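
Below is a minimal usage sketch of the two renamed context managers as seen from the user-facing `Fabric` API that this patch updates. It is illustrative only and not part of the patch: the accelerator, the `bf16-true` precision choice, and the toy `torch.nn.Linear` model are assumptions picked for the example.

    import torch
    from lightning.fabric import Fabric

    # Assumption for illustration: one CUDA device and true bf16 precision.
    fabric = Fabric(accelerator="cuda", devices=1, precision="bf16-true")
    fabric.launch()

    # Parameters are created directly on the target device and in torch.bfloat16,
    # so no full-precision CPU copy is materialized first (direct device placement
    # requires PyTorch >= 2.0, as the warning in fabric.py points out).
    with fabric.init_module():
        model = torch.nn.Linear(4096, 4096)  # stand-in for a large model

    model = fabric.setup(model)

    # Floating-point tensors created here land on the device with the configured
    # dtype; integer tensors keep their dtype (see test_tensor_init_context above).
    with fabric.init_tensor():
        batch = torch.randn(8, 4096)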