Address feedback for `Fabric.init_module()` (4/4) (#17607)

Adrian Wälchli 2023-06-03 04:07:02 +02:00 committed by GitHub
parent 255b18823e
commit 67a14795cf
9 changed files with 84 additions and 35 deletions

View File

@@ -33,16 +33,15 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added a warning when calling methods on `_FabricModule` that bypass the strategy-specific wrappers ([#17424](https://github.com/Lightning-AI/lightning/pull/17424))
- Added `Fabric.init()` context manager to instantiate tensors or models efficiently directly on device and dtype ([#17488](https://github.com/Lightning-AI/lightning/pull/17488))
- Added `Fabric.init_tensor()` context manager to instantiate tensors efficiently directly on device and dtype ([#17488](https://github.com/Lightning-AI/lightning/pull/17488))
- Added `Fabric.init_module()` context manager to instantiate large models efficiently directly on device, dtype, and with sharding support ([#17462](https://github.com/Lightning-AI/lightning/pull/17462))
- Added `lightning.fabric.plugins.Precision.init_context()` context manager to control model and tensor instantiation ([#17462](https://github.com/Lightning-AI/lightning/pull/17462))
* Creates the model parameters in the desired dtype (`torch.float32`, `torch.float64`, `torch.float16`, or `torch.bfloat16`) depending on the 'true' precision choice in `Fabric(precision='32-true'|'64-true'|'16-true'|'bf16-true')`
* Handles initialization for FSDP models before wrapping and the ZeRO stage 3 initialization for DeepSpeed before sharding
- Added `lightning.fabric.plugins.Precision.init_context()` and `lightning.fabric.strategies.Strategy.module_init_context()` context managers to control model and tensor instantiation ([#17462](https://github.com/Lightning-AI/lightning/pull/17462))
- Added `lightning.fabric.strategies.Strategy.init_context()` context manager to control the model and tensor instantiation ([#17462](https://github.com/Lightning-AI/lightning/pull/17462))
* Calls `lightning.fabric.plugins.Precision.init_context()`
* Initializes empty weights on the root device.
- Run the DDP wrapper in a CUDA stream ([#17334](https://github.com/Lightning-AI/lightning/pull/17334))
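Taken together, the new entries above describe the following user-facing flow. This is a minimal sketch added for orientation, not part of the changelog itself; the accelerator, precision choice, and model are illustrative assumptions (any supported combination works):

import torch
from lightning.fabric import Fabric

fabric = Fabric(accelerator="cuda", devices=1, precision="16-true")

# Parameters are created directly on the target device and in torch.float16
# (with PyTorch 2.0+; on older versions they stay on CPU until setup()).
with fabric.init_module():
    model = torch.nn.Linear(32, 2)

# Plain tensors pick up the same device and floating-point dtype.
with fabric.init_tensor():
    batch = torch.randn(4, 32)

model = fabric.setup(model)
output = model(batch)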

View File

@@ -593,20 +593,20 @@ class Fabric:
yield
@contextmanager
def init(self) -> Generator:
"""Instantiate under this context manager to apply improvements based on your configuration.
def init_tensor(self) -> Generator:
"""Tensors that you instantiate under this context manager will be created on the device right away and
have the right data type depending on the precision setting in Fabric.
The parameters get created on the device (if using PyTorch 2.0 or newer) and with the right data type right away,
without unnecessary memory allocation.
The automatic device placement under this context manager is only supported with PyTorch 2.0 and newer.
"""
if not _TORCH_GREATER_EQUAL_2_0 and self.device.type != "cpu":
rank_zero_warn(
"`Fabric.init()` can't place the model parameters on the device directly"
"`Fabric.init_tensor()` can't place tensors on the device directly"
" with PyTorch < 2.0. Parameters will remain on CPU until `Fabric.setup()` is called."
" Upgrade to PyTorch >= 2.0 to fully utilize this feature.",
category=PossibleUserWarning,
)
with self._strategy.init_context():
with self._strategy.tensor_init_context():
yield
@contextmanager
@@ -624,7 +624,7 @@ class Fabric:
" Upgrade to PyTorch >= 2.0 to fully utilize this feature.",
category=PossibleUserWarning,
)
with self._strategy.init_context(), _old_sharded_model_context(self.strategy):
with self._strategy.module_init_context():
yield
def save(self, path: Union[str, Path], state: Dict[str, Union[nn.Module, Optimizer, Any]]) -> None:
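A note for readers of this hunk: the renamed `init_tensor()` only overrides the default floating-point dtype, mirroring the behaviour verified by the new tests further down. A minimal sketch, assuming a CPU run with double precision (both are assumptions, not part of the diff):

import torch
from lightning.fabric import Fabric

fabric = Fabric(accelerator="cpu", precision="64-true")

with fabric.init_tensor():
    a = torch.tensor(1.0)                    # follows the precision plugin -> torch.float64
    b = torch.tensor(1)                      # integer dtypes are left alone -> torch.long
    c = torch.tensor(1.0, dtype=torch.half)  # an explicit dtype always wins -> torch.float16

assert a.dtype == torch.float64 and b.dtype == torch.long and c.dtype == torch.half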

View File

@@ -16,7 +16,7 @@ import json
import logging
import os
import platform
from contextlib import contextmanager
from contextlib import contextmanager, nullcontext
from itertools import chain
from pathlib import Path
from typing import Any, Dict, Generator, List, Mapping, Optional, Tuple, TYPE_CHECKING, Union
@@ -339,6 +339,12 @@ class DeepSpeedStrategy(DDPStrategy, _Sharded):
"""
raise NotImplementedError(self._err_msg_joint_setup_required())
@contextmanager
def module_init_context(self) -> Generator[None, None, None]:
precision_context = self.precision.init_context() if not self.zero_stage_3 else nullcontext()
with precision_context, self.module_sharded_context():
yield
@contextmanager
def module_sharded_context(self) -> Generator[None, None, None]:
# Current limitation in Fabric: The config needs to be fully determined at the time of calling the context
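For orientation, a hedged sketch of how this DeepSpeed path is typically exercised from user code; the strategy alias, precision flag, layer sizes, and optimizer below are assumptions, not part of the diff. As the hunk shows, the precision context is skipped when ZeRO stage 3 is active (DeepSpeed's own sharded init handles the dtype) and the sharded context partitions parameters as they are created:

import torch
from lightning.fabric import Fabric

fabric = Fabric(accelerator="cuda", devices=2, strategy="deepspeed_stage_3", precision="16-mixed")
fabric.launch()

with fabric.init_module():
    # Parameters are partitioned across ranks as each layer is built, so the
    # full model never needs to fit on a single device.
    model = torch.nn.Sequential(*(torch.nn.Linear(4096, 4096) for _ in range(8)))

# DeepSpeed requires the model and optimizer to be set up together.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
model, optimizer = fabric.setup(model, optimizer)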

View File

@@ -254,10 +254,10 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
pass
@contextmanager
def init_context(self) -> Generator[None, None, None]:
def module_init_context(self) -> Generator[None, None, None]:
# TODO: Use the meta device and reset parameters after https://github.com/pytorch/pytorch/issues/90465
# is resolved. For now, the module will get moved to the device in `setup_module`.
with self.precision.init_context():
with self.precision.init_context(), self.module_sharded_context():
yield
@contextmanager
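Again purely as orientation, a sketch of the FSDP flow this hunk affects; the strategy alias, precision, and model are placeholders, and bfloat16 assumes supporting hardware. Until the meta-device TODO above is resolved, the final device move still happens in `setup_module`:

import torch
from lightning.fabric import Fabric

fabric = Fabric(accelerator="cuda", devices=2, strategy="fsdp", precision="bf16-true")
fabric.launch()

with fabric.init_module():
    # Created in bfloat16 right away; moved to the device during setup for now.
    model = torch.nn.Transformer(num_encoder_layers=12, num_decoder_layers=12)

model = fabric.setup(model)  # FSDP wrapping and final device placement happen here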

View File

@@ -114,14 +114,20 @@ class Strategy(ABC):
return dataloader
@contextmanager
def init_context(self) -> Generator:
"""A context manager for improved tensor and module instantiation.
def tensor_init_context(self) -> Generator:
"""Controls how tensors get created (device, dtype)."""
device_context = self.root_device if _TORCH_GREATER_EQUAL_2_0 else nullcontext()
with device_context, self.precision.init_context():
yield
@contextmanager
def module_init_context(self) -> Generator:
"""A context manager wrapping the model instantiation.
Here, the strategy can control how the parameters of the model get created (device, dtype) and/or apply other
patches to the model.
"""
device_context = self.root_device if _TORCH_GREATER_EQUAL_2_0 else nullcontext()
with device_context, self.precision.init_context():
with self.tensor_init_context():
yield
def setup_module_and_optimizers(
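For readers unfamiliar with the pattern in the base class above: the root device (usable as a context manager since PyTorch 2.0) and the precision plugin's dtype context compose as plain context managers. The following is a standalone approximation added for illustration, not the Fabric code itself:

from contextlib import contextmanager
from typing import Generator

import torch

@contextmanager
def tensor_init_context(device: torch.device, dtype: torch.dtype) -> Generator[None, None, None]:
    previous = torch.get_default_dtype()
    torch.set_default_dtype(dtype)  # roughly what a "true"-precision plugin's init_context does
    try:
        with device:  # torch.device works as a context manager in PyTorch >= 2.0
            yield
    finally:
        torch.set_default_dtype(previous)

with tensor_init_context(torch.device("cpu"), torch.float64):
    weight = torch.empty(2, 2)

assert weight.dtype == torch.float64

Building `module_init_context()` on top of `tensor_init_context()` keeps tensor and module creation consistent, with subclasses adding their own sharding contexts on top.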

View File

@@ -140,7 +140,7 @@ def test_ddp_grad_clipping(clip_type, accelerator, precision):
],
)
@mock.patch.dict(os.environ, {"LOCAL_RANK": "1"})
def test_init_context(precision, expected_dtype):
def test_module_init_context(precision, expected_dtype):
"""Test that the module under the init-context gets moved to the right device and dtype."""
parallel_devices = [torch.device("cuda", 0), torch.device("cuda", 1)]
expected_device = parallel_devices[1] if _TORCH_GREATER_EQUAL_2_0 else torch.device("cpu")
@@ -149,7 +149,7 @@ def test_init_context(precision, expected_dtype):
parallel_devices=parallel_devices, precision=precision, cluster_environment=LightningEnvironment()
)
assert strategy.local_rank == 1
with strategy.init_context():
with strategy.module_init_context():
module = torch.nn.Linear(2, 2)
assert module.weight.device == module.bias.device == expected_device
assert module.weight.dtype == module.bias.dtype == expected_dtype

View File

@@ -284,7 +284,7 @@ def test_compile(compile_after_setup):
("64-true", torch.float64),
],
)
def test_init_context(precision, expected_dtype):
def test_module_init_context(precision, expected_dtype):
"""Test that the module under the init-context gets moved to the right device and dtype."""
fabric = Fabric(
accelerator="cuda",

View File

@@ -168,13 +168,46 @@ def test_single_device_grad_clipping(clip_type, precision):
pytest.param(DoublePrecision(), torch.float64, marks=RunIf(mps=False)),
],
)
def test_init_context(device, precision, dtype):
"""Test that the module under the init-context gets moved to the right device and dtype."""
def test_module_init_context(device, precision, dtype):
"""Test that the module under the init-module-context gets moved to the right device and dtype."""
device = torch.device(device)
strategy = SingleDeviceStrategy(device=device, precision=precision)
with strategy.init_context():
with strategy.module_init_context():
module = torch.nn.Linear(2, 2)
expected_device = device if _TORCH_GREATER_EQUAL_2_0 else torch.device("cpu")
assert module.weight.device == module.bias.device == expected_device
assert module.weight.dtype == module.bias.dtype == dtype
@pytest.mark.parametrize(
"device",
[
"cpu",
pytest.param("cuda:0", marks=RunIf(min_cuda_gpus=1)),
pytest.param("mps:0", marks=RunIf(mps=True)),
],
)
@pytest.mark.parametrize(
("precision", "dtype"),
[
(Precision(), torch.float32),
(HalfPrecision("16-true"), torch.float16),
pytest.param(HalfPrecision("bf16-true"), torch.bfloat16, marks=RunIf(mps=False)),
pytest.param(DoublePrecision(), torch.float64, marks=RunIf(mps=False)),
],
)
def test_tensor_init_context(device, precision, dtype):
"""Test that tensors under the init-tensor-context get moved to the right device and dtype."""
device = torch.device(device)
strategy = SingleDeviceStrategy(device=device, precision=precision)
with strategy.tensor_init_context():
tensor0 = torch.tensor(42.0)
tensor1 = torch.tensor(42)
tensor2 = torch.tensor(42.0, dtype=torch.half)
expected_device = device if _TORCH_GREATER_EQUAL_2_0 else torch.device("cpu")
assert tensor0.device == tensor1.device == tensor2.device == expected_device
assert tensor0.dtype == dtype
assert tensor1.dtype == torch.long # `.init_tensor()` only affects floating point dtypes
assert tensor2.dtype == torch.half # this tensor was created with an explicit dtype assignment

View File

@@ -735,35 +735,40 @@ def test_init_module_context(monkeypatch):
fabric = Fabric(accelerator="cpu")
strategy = SingleDeviceStrategy(device=torch.device("cuda"))
strategy.init_context = Mock(wraps=strategy.init_context)
strategy.module_init_context = Mock(wraps=strategy.module_init_context)
fabric._strategy = strategy
with fabric.init_module():
pass
strategy.init_context.assert_called_once()
strategy.init_context.reset_mock()
strategy.module_init_context.assert_called_once()
strategy.module_init_context.reset_mock()
# Pretend we are using PyTorch < 2.0
monkeypatch.setattr(lightning.fabric.fabric, "_TORCH_GREATER_EQUAL_2_0", False)
with pytest.warns(PossibleUserWarning, match="can't place the model parameters on the device"): # noqa: SIM117
with fabric.init_module():
pass
strategy.init_context.assert_called_once()
strategy.module_init_context.assert_called_once()
def test_init_context(monkeypatch):
"""Test that `.init()` warns if using PyTorch < 2.0."""
# TODO(awaelchli): Extend the test once `Fabric.init()` finalized
def test_init_tensor_context(monkeypatch):
"""Test that `.init_tensor()` warns if using PyTorch < 2.0."""
import lightning.fabric
fabric = Fabric(accelerator="cpu")
strategy = SingleDeviceStrategy(device=torch.device("cuda"))
strategy.tensor_init_context = Mock(wraps=strategy.tensor_init_context)
fabric._strategy = strategy
with fabric.init_tensor():
pass
strategy.tensor_init_context.assert_called_once()
strategy.tensor_init_context.reset_mock()
# Pretend we are using PyTorch < 2.0
monkeypatch.setattr(lightning.fabric.fabric, "_TORCH_GREATER_EQUAL_2_0", False)
with pytest.warns(PossibleUserWarning, match="can't place the model parameters on the device"): # noqa: SIM117
with fabric.init():
with pytest.warns(PossibleUserWarning, match="can't place tensors on the device directly"): # noqa: SIM117
with fabric.init_tensor():
pass
strategy.tensor_init_context.assert_called_once()
def test_callbacks_input():