Address feedback for `Fabric.init_module()` (4/4) (#17607)
parent 255b18823e
commit 67a14795cf

@@ -33,16 +33,15 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added a warning when calling methods on `_FabricModule` that bypass the strategy-specific wrappers ([#17424](https://github.com/Lightning-AI/lightning/pull/17424))

-- Added `Fabric.init()` context manager to instantiate tensors or models efficiently directly on device and dtype ([#17488](https://github.com/Lightning-AI/lightning/pull/17488))
+- Added `Fabric.init_tensor()` context manager to instantiate tensors efficiently directly on device and dtype ([#17488](https://github.com/Lightning-AI/lightning/pull/17488))

 - Added `Fabric.init_module()` context manager to instantiate large models efficiently directly on device, dtype, and with sharding support ([#17462](https://github.com/Lightning-AI/lightning/pull/17462))

-- Added `lightning.fabric.plugins.Precision.init_context()` context manager to control model and tensor instantiation ([#17462](https://github.com/Lightning-AI/lightning/pull/17462))
   * Creates the model parameters in the desired dtype (`torch.float32`, `torch.float64`, `torch.float16`, or `torch.bfloat16`) depending on the 'true' precision choice in `Fabric(precision='32-true'|'64-true'|'16-true'|'bf16-true')`
   * Handles initialization for FSDP models before wrapping and the Zero stage 3 initialization for DeepSpeed before sharding

+- Added `lightning.fabric.plugins.Precision.init_context()` and `lightning.fabric.strategies.Strategy.module_init_context()` context managers to control model and tensor instantiation ([#17462](https://github.com/Lightning-AI/lightning/pull/17462))

-- Added `lightning.fabric.strategies.Strategy.init_context()` context manager to control the model and tensor instantiation ([#17462](https://github.com/Lightning-AI/lightning/pull/17462))
   * Calls `lightning.fabric.plugins.Precision.init_context()`
   * Initializes empty weights on the root device.

 - Run the DDP wrapper in a CUDA stream ([#17334](https://github.com/Lightning-AI/lightning/pull/17334))

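For orientation (not part of the diff): a minimal usage sketch of the two `Fabric` context managers described in the entries above, assuming a single CUDA device, PyTorch >= 2.0, and a Lightning version that includes this change.

```python
import torch
from lightning.fabric import Fabric

fabric = Fabric(accelerator="cuda", devices=1, precision="16-true")

# Model parameters are created directly on the GPU and in torch.float16 ("16-true").
with fabric.init_module():
    model = torch.nn.Linear(2, 2)

# Plain tensors get the same device/dtype treatment under init_tensor().
with fabric.init_tensor():
    running_stats = torch.zeros(2)
```
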
@@ -593,20 +593,20 @@ class Fabric:
             yield

     @contextmanager
-    def init(self) -> Generator:
-        """Instantiate under this context manager to apply improvements based on your configuration.
+    def init_tensor(self) -> Generator:
+        """Tensors that you instantiate under this context manager will be created on the device right away and
+        have the right data type depending on the precision setting in Fabric.

-        The parameters get created on the device (if using PyTorch 2.0 or newer) and with the right data type right away
-        without wasting memory being allocated unnecessarily.
+        The automatic device placement under this context manager is only supported with PyTorch 2.0 and newer.
         """
         if not _TORCH_GREATER_EQUAL_2_0 and self.device.type != "cpu":
             rank_zero_warn(
-                "`Fabric.init()` can't place the model parameters on the device directly"
+                "`Fabric.init_tensor()` can't place tensors on the device directly"
                 " with PyTorch < 2.0. Parameters will remain on CPU until `Fabric.setup()` is called."
                 " Upgrade to PyTorch >= 2.0 to fully utilize this feature.",
                 category=PossibleUserWarning,
             )
-        with self._strategy.init_context():
+        with self._strategy.tensor_init_context():
             yield

     @contextmanager

@@ -624,7 +624,7 @@ class Fabric:
                 " Upgrade to PyTorch >= 2.0 to fully utilize this feature.",
                 category=PossibleUserWarning,
             )
-        with self._strategy.init_context(), _old_sharded_model_context(self.strategy):
+        with self._strategy.module_init_context():
             yield

     def save(self, path: Union[str, Path], state: Dict[str, Union[nn.Module, Optimizer, Any]]) -> None:

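As a reading aid (not part of the diff), a small sketch of the dtype behaviour described in the new `init_tensor()` docstring, runnable on CPU; the `"64-true"` precision string is taken from the test parametrizations further below. Only the default floating-point dtype is switched; integer tensors and tensors created with an explicit `dtype` are left alone.

```python
import torch
from lightning.fabric import Fabric

fabric = Fabric(accelerator="cpu", precision="64-true")

with fabric.init_tensor():
    a = torch.tensor(1.23)                    # torch.float64, from the "64-true" setting
    b = torch.tensor(42)                      # still torch.int64: integers are unaffected
    c = torch.tensor(1.23, dtype=torch.half)  # an explicit dtype always wins
```
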
@@ -16,7 +16,7 @@ import json
 import logging
 import os
 import platform
-from contextlib import contextmanager
+from contextlib import contextmanager, nullcontext
 from itertools import chain
 from pathlib import Path
 from typing import Any, Dict, Generator, List, Mapping, Optional, Tuple, TYPE_CHECKING, Union

@@ -339,6 +339,12 @@ class DeepSpeedStrategy(DDPStrategy, _Sharded):
         """
         raise NotImplementedError(self._err_msg_joint_setup_required())

+    @contextmanager
+    def module_init_context(self) -> Generator[None, None, None]:
+        precision_context = self.precision.init_context() if not self.zero_stage_3 else nullcontext()
+        with precision_context, self.module_sharded_context():
+            yield
+
     @contextmanager
     def module_sharded_context(self) -> Generator[None, None, None]:
         # Current limitation in Fabric: The config needs to be fully determined at the time of calling the context

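The notable detail in the DeepSpeed hunk above is the `nullcontext()` fallback: with ZeRO stage 3 the precision context is bypassed, and `module_sharded_context()` alone wraps parameter creation. A generic, self-contained sketch of that conditional-context pattern (illustrative names, not the Lightning API):

```python
from contextlib import contextmanager, nullcontext


@contextmanager
def precision_context():
    print("precision context entered")
    yield


@contextmanager
def sharded_context():
    print("sharded context entered")
    yield


def module_init_sketch(zero_stage_3: bool):
    # Mirror of the pattern above: swap the precision context for a no-op when stage 3 is active.
    return precision_context() if not zero_stage_3 else nullcontext()


with module_init_sketch(zero_stage_3=True), sharded_context():
    pass  # only "sharded context entered" is printed
```
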
@@ -254,10 +254,10 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
         pass

     @contextmanager
-    def init_context(self) -> Generator[None, None, None]:
+    def module_init_context(self) -> Generator[None, None, None]:
         # TODO: Use the meta device and reset parameters after https://github.com/pytorch/pytorch/issues/90465
         # is resolved. For now, the module will get moved to the device in `setup_module`.
-        with self.precision.init_context():
+        with self.precision.init_context(), self.module_sharded_context():
             yield

     @contextmanager

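A hypothetical end-to-end sketch (not part of the diff) of where `FSDPStrategy.module_init_context()` ends up being used: parameters are created under `Fabric.init_module()`, and the module is wrapped later in `setup()`, in line with the TODO note above about moving the module in `setup_module`. It assumes at least two CUDA devices and a Lightning build containing this change.

```python
import torch
from lightning.fabric import Fabric

fabric = Fabric(accelerator="cuda", devices=2, strategy="fsdp", precision="bf16-true")
fabric.launch()

with fabric.init_module():
    # Parameters are created in bfloat16 here; FSDP wrapping and sharding happen in setup().
    model = torch.nn.Transformer(d_model=512, nhead=8)

model = fabric.setup(model)
```
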
@@ -114,14 +114,20 @@ class Strategy(ABC):
         return dataloader

     @contextmanager
-    def init_context(self) -> Generator:
-        """A context manager for improved tensor and module instantiation.
+    def tensor_init_context(self) -> Generator:
+        """Controls how tensors get created (device, dtype)."""
+        device_context = self.root_device if _TORCH_GREATER_EQUAL_2_0 else nullcontext()
+        with device_context, self.precision.init_context():
+            yield
+
+    @contextmanager
+    def module_init_context(self) -> Generator:
+        """A context manager wrapping the model instantiation.

         Here, the strategy can control how the parameters of the model get created (device, dtype) and or apply other
         patches to the model.
         """
-        device_context = self.root_device if _TORCH_GREATER_EQUAL_2_0 else nullcontext()
-        with device_context, self.precision.init_context():
+        with self.tensor_init_context():
             yield

     def setup_module_and_optimizers(

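To make the split concrete, here is a standalone sketch (illustrative names, assuming PyTorch >= 2.0; not the Lightning API) of the composition the base class now uses: the tensor context stacks a device context with a precision context, and the module context simply reuses the tensor context.

```python
from contextlib import contextmanager

import torch


@contextmanager
def precision_init_context(dtype: torch.dtype):
    previous = torch.get_default_dtype()
    torch.set_default_dtype(dtype)  # new floating-point tensors pick up this dtype
    try:
        yield
    finally:
        torch.set_default_dtype(previous)


@contextmanager
def tensor_init_context(root_device: torch.device, dtype: torch.dtype):
    # torch.device objects are usable as context managers in PyTorch 2.0+
    with root_device, precision_init_context(dtype):
        yield


@contextmanager
def module_init_context(root_device: torch.device, dtype: torch.dtype):
    # The module context is just the tensor context, as in the base class above.
    with tensor_init_context(root_device, dtype):
        yield


with module_init_context(torch.device("cpu"), torch.float64):
    layer = torch.nn.Linear(2, 2)  # float64 parameters, created on the root device
```
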
@@ -140,7 +140,7 @@ def test_ddp_grad_clipping(clip_type, accelerator, precision):
     ],
 )
 @mock.patch.dict(os.environ, {"LOCAL_RANK": "1"})
-def test_init_context(precision, expected_dtype):
+def test_module_init_context(precision, expected_dtype):
     """Test that the module under the init-context gets moved to the right device and dtype."""
     parallel_devices = [torch.device("cuda", 0), torch.device("cuda", 1)]
     expected_device = parallel_devices[1] if _TORCH_GREATER_EQUAL_2_0 else torch.device("cpu")

@@ -149,7 +149,7 @@ def test_init_context(precision, expected_dtype):
         parallel_devices=parallel_devices, precision=precision, cluster_environment=LightningEnvironment()
     )
     assert strategy.local_rank == 1
-    with strategy.init_context():
+    with strategy.module_init_context():
         module = torch.nn.Linear(2, 2)
     assert module.weight.device == module.bias.device == expected_device
     assert module.weight.dtype == module.bias.dtype == expected_dtype

@@ -284,7 +284,7 @@ def test_compile(compile_after_setup):
         ("64-true", torch.float64),
     ],
 )
-def test_init_context(precision, expected_dtype):
+def test_module_init_context(precision, expected_dtype):
     """Test that the module under the init-context gets moved to the right device and dtype."""
     fabric = Fabric(
         accelerator="cuda",

@@ -168,13 +168,46 @@ def test_single_device_grad_clipping(clip_type, precision):
         pytest.param(DoublePrecision(), torch.float64, marks=RunIf(mps=False)),
     ],
 )
-def test_init_context(device, precision, dtype):
-    """Test that the module under the init-context gets moved to the right device and dtype."""
+def test_module_init_context(device, precision, dtype):
+    """Test that the module under the init-module-context gets moved to the right device and dtype."""
     device = torch.device(device)
     strategy = SingleDeviceStrategy(device=device, precision=precision)
-    with strategy.init_context():
+    with strategy.module_init_context():
         module = torch.nn.Linear(2, 2)

     expected_device = device if _TORCH_GREATER_EQUAL_2_0 else torch.device("cpu")
     assert module.weight.device == module.bias.device == expected_device
     assert module.weight.dtype == module.bias.dtype == dtype
+
+
+@pytest.mark.parametrize(
+    "device",
+    [
+        "cpu",
+        pytest.param("cuda:0", marks=RunIf(min_cuda_gpus=1)),
+        pytest.param("mps:0", marks=RunIf(mps=True)),
+    ],
+)
+@pytest.mark.parametrize(
+    ("precision", "dtype"),
+    [
+        (Precision(), torch.float32),
+        (HalfPrecision("16-true"), torch.float16),
+        pytest.param(HalfPrecision("bf16-true"), torch.bfloat16, marks=RunIf(mps=False)),
+        pytest.param(DoublePrecision(), torch.float64, marks=RunIf(mps=False)),
+    ],
+)
+def test_tensor_init_context(device, precision, dtype):
+    """Test that tensors under the init-tensor-context get moved to the right device and dtype."""
+    device = torch.device(device)
+    strategy = SingleDeviceStrategy(device=device, precision=precision)
+    with strategy.tensor_init_context():
+        tensor0 = torch.tensor(42.0)
+        tensor1 = torch.tensor(42)
+        tensor2 = torch.tensor(42.0, dtype=torch.half)
+
+    expected_device = device if _TORCH_GREATER_EQUAL_2_0 else torch.device("cpu")
+    assert tensor0.device == tensor1.device == tensor2.device == expected_device
+    assert tensor0.dtype == dtype
+    assert tensor1.dtype == torch.long  # `.init_tensor()` only affects floating point dtypes
+    assert tensor2.dtype == torch.half  # this tensor was created with an explicit dtype assignment

@@ -735,35 +735,40 @@ def test_init_module_context(monkeypatch):

     fabric = Fabric(accelerator="cpu")
     strategy = SingleDeviceStrategy(device=torch.device("cuda"))
-    strategy.init_context = Mock(wraps=strategy.init_context)
+    strategy.module_init_context = Mock(wraps=strategy.module_init_context)
     fabric._strategy = strategy
     with fabric.init_module():
         pass
-    strategy.init_context.assert_called_once()
-    strategy.init_context.reset_mock()
+    strategy.module_init_context.assert_called_once()
+    strategy.module_init_context.reset_mock()

     # Pretend we are using PyTorch < 2.0
     monkeypatch.setattr(lightning.fabric.fabric, "_TORCH_GREATER_EQUAL_2_0", False)
     with pytest.warns(PossibleUserWarning, match="can't place the model parameters on the device"):  # noqa: SIM117
         with fabric.init_module():
             pass
-    strategy.init_context.assert_called_once()
+    strategy.module_init_context.assert_called_once()


-def test_init_context(monkeypatch):
-    """Test that `.init()` warns if using PyTorch < 2.0."""
-    # TODO(awaelchli): Extend the test once `Fabric.init()` finalized
+def test_init_tensor_context(monkeypatch):
+    """Test that `.init_tensor()` warns if using PyTorch < 2.0."""
     import lightning.fabric

     fabric = Fabric(accelerator="cpu")
     strategy = SingleDeviceStrategy(device=torch.device("cuda"))
+    strategy.tensor_init_context = Mock(wraps=strategy.tensor_init_context)
     fabric._strategy = strategy
+    with fabric.init_tensor():
+        pass
+    strategy.tensor_init_context.assert_called_once()
+    strategy.tensor_init_context.reset_mock()

     # Pretend we are using PyTorch < 2.0
     monkeypatch.setattr(lightning.fabric.fabric, "_TORCH_GREATER_EQUAL_2_0", False)
-    with pytest.warns(PossibleUserWarning, match="can't place the model parameters on the device"):  # noqa: SIM117
-        with fabric.init():
+    with pytest.warns(PossibleUserWarning, match="can't place tensors on the device directly"):  # noqa: SIM117
+        with fabric.init_tensor():
             pass
+    strategy.tensor_init_context.assert_called_once()


 def test_callbacks_input():

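Side note on the test technique used above (a generic sketch, not tied to Lightning): `Mock(wraps=...)` keeps the original behaviour while recording calls, which is what lets these tests assert that `init_module()` and `init_tensor()` really delegate to the strategy hooks.

```python
from unittest.mock import Mock


class Greeter:
    def hello(self) -> str:
        return "hello"


greeter = Greeter()
greeter.hello = Mock(wraps=greeter.hello)

assert greeter.hello() == "hello"   # the wrapped method still runs
greeter.hello.assert_called_once()  # ...and the call was recorded
```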