Improved model initialization API for Fabric (#17462)

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
Adrian Wälchli 2023-04-26 17:25:33 +02:00 committed by GitHub
parent d48ec08d76
commit 4d17b5fe77
15 changed files with 224 additions and 25 deletions


@ -139,6 +139,23 @@ Make your code reproducible by calling this method at the beginning of your run.
This covers PyTorch, NumPy, and Python random number generators. In addition, Fabric takes care of properly initializing
the seed of data loader worker processes (can be turned off by passing ``workers=False``).
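For instance, a minimal illustration (assuming an existing ``fabric`` instance):

.. code-block:: python

    fabric.seed_everything(1234)                 # seeds the PyTorch, NumPy, and Python RNGs
    fabric.seed_everything(1234, workers=False)  # skip seeding the data loader worker processes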
init_module
===========
Instantiating a ``nn.Module`` in PyTorch creates all parameters on CPU in float32 precision by default.
To speed up initialization, you can force PyTorch to create the model directly on the target device and with the desired precision without changing your model code.
.. code-block:: python
fabric = Fabric(accelerator="cuda", precision="16-true")
with fabric.init_module():
# models created here will be on GPU and in float16
model = MyModel()
This eliminates the time otherwise spent transferring the model parameters from the CPU to the device.
For strategies that handle large sharded models (FSDP, DeepSpeed), the :meth:`~lightning.fabric.fabric.Fabric.init_module` method allocates the model parameters on the meta device before sharding them.
This makes it possible to work with models that are larger than the memory of a single device.
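For example, here is a hedged sketch with an FSDP strategy (``MyModel`` is a placeholder); with PyTorch 2.0+, the parameters stay on the meta device inside the context and only get materialized and sharded in ``fabric.setup()``:

.. code-block:: python

    from lightning.fabric import Fabric
    from lightning.fabric.strategies import FSDPStrategy

    fabric = Fabric(accelerator="cuda", devices=2, strategy=FSDPStrategy(), precision="bf16-true")
    fabric.launch()

    with fabric.init_module():
        # parameters are created on the meta device; no device memory is allocated yet
        model = MyModel()

    # sharding and placement on the target devices happen here
    model = fabric.setup(model)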
autocast
========


@ -30,6 +30,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added a warning when calling methods on `_FabricModule` that bypass the strategy-specific wrappers ([#17424](https://github.com/Lightning-AI/lightning/pull/17424))
- Added `Fabric.init_module()` context manager to instantiate large models efficiently ([#17462](https://github.com/Lightning-AI/lightning/pull/17462))
- Added `lightning.fabric.strategies.Strategy.module_init_context()` context manager to control the model instantiation ([#17462](https://github.com/Lightning-AI/lightning/pull/17462))
* Creates the model parameters in the desired dtype (`torch.float32`, `torch.float64`, `torch.float16`, or `torch.bfloat16`) depending on the 'true' precision choice in `Fabric(precision='32-true'|'64-true'|'16-true'|'bf16-true')`
* Initializes empty weights on the meta device for FSDP models and handles the ZeRO stage 3 initialization for DeepSpeed before sharding
- Run the DDP wrapper in a CUDA stream ([#17334](https://github.com/Lightning-AI/lightning/pull/17334))
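For reference, a hedged summary of the precision-to-dtype mapping described in the entry above (illustrative only, not the plugin's actual implementation):

```python
import torch

# hypothetical lookup table mirroring the 'true' precision choices
_true_precision_to_dtype = {
    "32-true": torch.float32,
    "64-true": torch.float64,
    "16-true": torch.float16,
    "bf16-true": torch.bfloat16,
}
```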


@ -22,7 +22,7 @@ import torch
import torch.nn as nn
from lightning_utilities.core.apply_func import apply_to_collection
from lightning_utilities.core.overrides import is_overridden
from lightning_utilities.core.rank_zero import rank_zero_warn
from lightning_utilities.core.rank_zero import rank_zero_deprecation, rank_zero_warn
from torch import Tensor
from torch.optim import Optimizer
from torch.utils.data import BatchSampler, DataLoader, DistributedSampler, RandomSampler, SequentialSampler
@ -574,19 +574,30 @@ class Fabric:
def sharded_model(self) -> Generator:
"""Shard the parameters of the model instantly when instantiating the layers.
Use this context manager with strategies that support sharding the model parameters to save peak memory usage.
Example::
with self.sharded_model():
model = MyModel()
The context manager is strategy-agnostic; for strategies that don't shard the model, it is a no-op.
.. deprecated:: This context manager is deprecated in favor of :meth:`init_module`; use that instead.
"""
if isinstance(self._strategy, _Sharded):
with self._strategy.module_sharded_context():
yield
else:
rank_zero_deprecation("`Fabric.sharded_model()` is deprecated in favor of `Fabric.init_module()`.")
with _old_sharded_model_context(self._strategy):
yield
@contextmanager
def init_module(self) -> Generator:
"""Instantiate the model and its parameters under this context manager to reduce peak memory usage.
The parameters are created on the target device and in the desired dtype right away, without unnecessary intermediate memory allocations.
Note:
The automatic device placement under this context manager is only supported with PyTorch 2.0 and newer.
"""
if not _TORCH_GREATER_EQUAL_2_0 and self.device.type != "cpu":
rank_zero_warn(
"`Fabric.init_module()` can't place the model parameters on the device directly with PyTorch < 2.0."
" Parameters will remain on CPU until `Fabric.setup()` is called. Upgrade to PyTorch >= 2.0 to fully"
" utilize the features in `init_module()`.",
category=PossibleUserWarning,
)
with self._strategy.module_init_context():
yield
def save(self, path: Union[str, Path], state: Dict[str, Union[nn.Module, Optimizer, Any]]) -> None:
@ -763,9 +774,9 @@ class Fabric:
def _run_with_setup(self, run_function: Callable, *args: Any, **kwargs: Any) -> Any:
self._strategy.setup_environment()
# apply sharded context to prevent OOM
with self.sharded_model(), _replace_dunder_methods(DataLoader, "dataset"), _replace_dunder_methods(
BatchSampler
):
with _old_sharded_model_context(self._strategy), _replace_dunder_methods(
DataLoader, "dataset"
), _replace_dunder_methods(BatchSampler):
return run_function(*args, **kwargs)
def _move_model_to_device(self, model: nn.Module, optimizers: List[Optimizer]) -> nn.Module:
@ -864,3 +875,12 @@ class Fabric:
def _is_using_cli() -> bool:
return bool(int(os.environ.get("LT_CLI_USED", "0")))
@contextmanager
def _old_sharded_model_context(strategy: Strategy) -> Generator:
if isinstance(strategy, _Sharded):
with strategy.module_sharded_context():
yield
else:
yield
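To show the intended migration path, a hedged sketch (``MyModel`` is a placeholder): code that previously wrapped model instantiation in the deprecated ``Fabric.sharded_model()`` can switch to ``Fabric.init_module()`` without other changes:

from lightning.fabric import Fabric

fabric = Fabric(accelerator="cuda", devices=2, strategy="deepspeed")
fabric.launch()

# before: deprecated, now emits a rank-zero deprecation message
with fabric.sharded_model():
    model = MyModel()

# after: additionally handles device placement and the 'true' precision dtypes
with fabric.init_module():
    model = MyModel()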


@ -31,6 +31,17 @@ class DoublePrecision(Precision):
def convert_module(self, module: Module) -> Module:
return module.double()
@contextmanager
def module_init_context(self) -> Generator[None, None, None]:
"""A context manager to change the default tensor type when initializing the parameters in a module.
See: :meth:`torch.set_default_tensor_type`
"""
default_dtype = torch.get_default_dtype()
torch.set_default_dtype(torch.float64)
yield
torch.set_default_dtype(default_dtype)
@contextmanager
def forward_context(self) -> Generator[None, None, None]:
"""A context manager to change the default tensor type.


@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
from contextlib import contextmanager
from typing import Any, Dict, Generator, Literal, Optional, Union
from torch import Tensor
@ -42,7 +42,15 @@ class Precision:
"""
return module
@contextlib.contextmanager
@contextmanager
def module_init_context(self) -> Generator[None, None, None]:
"""Instantiate the module parameters in the precision type this plugin handles.
This is optional; whether a plugin overrides it depends on the precision limitations during optimization.
"""
yield
@contextmanager
def forward_context(self) -> Generator[None, None, None]:
"""A contextmanager for managing model forward/training_step/evaluation_step/predict_step."""
yield
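A hedged sketch of how a subclass can override this hook, mirroring the DoublePrecision plugin in this PR (`MyBF16TruePrecision` is hypothetical):

from contextlib import contextmanager
from typing import Generator

import torch
from lightning.fabric.plugins import Precision

class MyBF16TruePrecision(Precision):
    @contextmanager
    def module_init_context(self) -> Generator[None, None, None]:
        # create parameters directly in bfloat16 while inside the context
        default_dtype = torch.get_default_dtype()
        torch.set_default_dtype(torch.bfloat16)
        yield
        torch.set_default_dtype(default_dtype)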


@ -338,6 +338,11 @@ class DeepSpeedStrategy(DDPStrategy, _Sharded):
"""
raise NotImplementedError(self._err_msg_joint_setup_required())
@contextmanager
def module_init_context(self) -> Generator[None, None, None]:
with super().module_init_context(), self.module_sharded_context():
yield
@contextmanager
def module_sharded_context(self) -> Generator[None, None, None]:
# Current limitation in Fabric: The config needs to be fully determined at the time of calling the context


@ -13,7 +13,7 @@
# limitations under the License.
import functools
import os
from contextlib import _GeneratorContextManager, contextmanager
from contextlib import _GeneratorContextManager, contextmanager, nullcontext
from datetime import timedelta
from pathlib import Path
from typing import Any, Dict, Generator, List, Optional, Tuple, Type, TYPE_CHECKING, Union
@ -243,6 +243,12 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
def module_to_device(self, module: Module) -> None:
pass
@contextmanager
def module_init_context(self) -> Generator[None, None, None]:
device_context = torch.device("meta") if _TORCH_GREATER_EQUAL_2_0 else nullcontext()
with device_context, self.precision.module_init_context(), self.module_sharded_context():
yield
@contextmanager
def module_sharded_context(self) -> Generator:
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel
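As an aside, a minimal standalone sketch (not part of the diff) of the meta-device mechanism the FSDP strategy relies on with PyTorch 2.0+:

import torch

# Under PyTorch >= 2.0, `torch.device("meta")` works as a context manager:
# tensors created inside have shape and dtype but no storage allocated.
with torch.device("meta"):
    layer = torch.nn.Linear(50_000, 50_000)  # no real memory is allocated
assert layer.weight.is_meta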


@ -13,7 +13,7 @@
# limitations under the License.
import logging
from abc import ABC, abstractmethod
from contextlib import contextmanager
from contextlib import contextmanager, nullcontext
from typing import Any, Dict, Generator, List, Optional, Tuple, TypeVar, Union
import torch
@ -28,6 +28,7 @@ from lightning.fabric.plugins.io.torch_io import TorchCheckpointIO
from lightning.fabric.plugins.precision import Precision
from lightning.fabric.strategies.launchers.launcher import _Launcher
from lightning.fabric.utilities.apply_func import move_data_to_device
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
from lightning.fabric.utilities.types import _PATH, _Stateful, Optimizable, ReduceOp
TBroadcast = TypeVar("TBroadcast")
@ -111,6 +112,17 @@ class Strategy(ABC):
"""
return dataloader
@contextmanager
def module_init_context(self) -> Generator:
"""A context manager wrapping the model instantiation.
Here, the strategy can control how the model parameters get created (device, dtype) and/or apply other
patches to the model.
"""
device_context = self.root_device if _TORCH_GREATER_EQUAL_2_0 else nullcontext()
with device_context, self.precision.module_init_context():
yield
def setup_module_and_optimizers(
self, module: Module, optimizers: List[Optimizer]
) -> Tuple[Module, List[Optimizer]]:
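For illustration, a hedged sketch of calling the strategy hook directly (normally it is invoked through `Fabric.init_module()`); this mirrors the single-device test further below:

import torch
from lightning.fabric.plugins import DoublePrecision
from lightning.fabric.strategies import SingleDeviceStrategy

strategy = SingleDeviceStrategy(device=torch.device("cpu"), precision=DoublePrecision())
with strategy.module_init_context():
    # created directly in float64; with PyTorch >= 2.0 also directly on the root device
    module = torch.nn.Linear(2, 2)
assert module.weight.dtype == torch.float64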


@ -17,6 +17,7 @@ import torch
import torch.nn as nn
from tests_fabric.helpers.models import BoringFabric
from tests_fabric.helpers.runif import RunIf
class BoringDoubleModule(nn.Module):
@ -50,6 +51,7 @@ class DoublePrecisionBoringFabric(BoringFabric):
assert model.layer.weight.grad.dtype == torch.float64
@RunIf(mps=False) # MPS doesn't support float64
def test_double_precision():
fabric = DoublePrecisionBoringFabric(devices=1, precision="64-true")
fabric.run()


@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from unittest import mock
from unittest.mock import MagicMock, Mock
@ -18,8 +19,11 @@ import pytest
import torch
from torch.nn.parallel import DistributedDataParallel
from lightning.fabric.plugins import DoublePrecision, Precision
from lightning.fabric.plugins.environments import LightningEnvironment
from lightning.fabric.strategies import DDPStrategy
from lightning.fabric.strategies.ddp import _DDPBackwardSyncControl
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
from tests_fabric.helpers.runif import RunIf
from tests_fabric.strategies.test_single_device import _MyFabricGradNorm, _MyFabricGradVal
@ -122,3 +126,27 @@ def test_ddp_grad_clipping(clip_type, accelerator, precision):
clipping_test_cls = _MyFabricGradNorm if clip_type == "norm" else _MyFabricGradVal
fabric = clipping_test_cls(accelerator=accelerator, devices=2, precision=precision, strategy="ddp")
fabric.run()
@RunIf(min_cuda_gpus=2)
@pytest.mark.parametrize(
"precision,expected_dtype",
[
(Precision(), torch.float32),
(DoublePrecision(), torch.float64),
],
)
@mock.patch.dict(os.environ, {"LOCAL_RANK": "1"})
def test_module_init_context(precision, expected_dtype):
"""Test that the module under the init-context gets moved to the right device and dtype."""
parallel_devices = [torch.device("cuda", 0), torch.device("cuda", 1)]
expected_device = parallel_devices[1] if _TORCH_GREATER_EQUAL_2_0 else torch.device("cpu")
strategy = DDPStrategy(
parallel_devices=parallel_devices, precision=precision, cluster_environment=LightningEnvironment()
)
assert strategy.local_rank == 1
with strategy.module_init_context():
module = torch.nn.Linear(2, 2)
assert module.weight.device == module.bias.device == expected_device
assert module.weight.dtype == module.bias.dtype == expected_dtype


@ -366,7 +366,7 @@ def test_deepspeed_save_load_checkpoint_zero_3(stage, tmp_path):
checkpoint_path = fabric.broadcast(tmp_path / "deepspeed-checkpoint")
with fabric.sharded_model():
with fabric.init_module():
model = BoringModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)
@ -390,7 +390,7 @@ def test_deepspeed_save_load_checkpoint_zero_3(stage, tmp_path):
# re-init all objects and resume
fabric = Fabric(accelerator="cuda", devices=2, strategy=DeepSpeedStrategy(stage=stage), precision="bf16")
fabric.launch()
with fabric.sharded_model():
with fabric.init_module():
model = BoringModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)


@ -191,3 +191,37 @@ def test_compile(compile_after_setup):
for _ in range(3):
model(torch.rand(2, 32, device=fabric.device)).sum().backward()
@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
@pytest.mark.parametrize(
"precision,expected_dtype",
[
("32-true", torch.float32),
("64-true", torch.float64),
],
)
def test_module_init_context(precision, expected_dtype):
"""Test that the module under the init-context gets moved to the right device and dtype."""
fabric = Fabric(
accelerator="cuda",
devices=2,
strategy=FSDPStrategy(auto_wrap_policy=always_wrap_policy),
precision=precision,
)
fabric.launch()
with fabric.init_module():
model = torch.nn.Linear(100, 100, bias=False)
# The model is on the meta device until `.setup()` is called
expected_device = torch.device("meta") if _TORCH_GREATER_EQUAL_2_0 else torch.device("cpu")
assert model.weight.device == expected_device
assert model.weight.dtype == expected_dtype
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model, optimizer = fabric.setup(model, optimizer)
# Parameters get sharded in `.setup()` and moved to the target device
assert model.weight.device == torch.device("cuda", fabric.local_rank)
assert model.weight.dtype == expected_dtype


@ -16,7 +16,9 @@ from unittest.mock import Mock
import pytest
import torch
from lightning.fabric.plugins import DoublePrecision, Precision
from lightning.fabric.strategies import SingleDeviceStrategy
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
from lightning.fabric.wrappers import _FabricModule, _FabricOptimizer
from tests_fabric.helpers.models import BoringFabric
from tests_fabric.helpers.runif import RunIf
@ -147,3 +149,30 @@ def test_single_device_grad_clipping(clip_type, precision):
clipping_test_cls = _MyFabricGradNorm if clip_type == "norm" else _MyFabricGradVal
fabric = clipping_test_cls(accelerator="auto", devices=1, precision=precision)
fabric.run()
@pytest.mark.parametrize(
"device",
[
"cpu",
pytest.param("cuda:0", marks=RunIf(min_cuda_gpus=1)),
pytest.param("mps:0", marks=RunIf(mps=True)),
],
)
@pytest.mark.parametrize(
"precision,dtype",
[
(Precision(), torch.float32),
pytest.param(DoublePrecision(), torch.float64, marks=RunIf(mps=False)),
],
)
def test_module_init_context(device, precision, dtype):
"""Test that the module under the init-context gets moved to the right device and dtype."""
device = torch.device(device)
strategy = SingleDeviceStrategy(device=device, precision=precision)
with strategy.module_init_context():
module = torch.nn.Linear(2, 2)
expected_device = device if _TORCH_GREATER_EQUAL_2_0 else torch.device("cpu")
assert module.weight.device == module.bias.device == expected_device
assert module.weight.dtype == module.bias.dtype == dtype


@ -277,7 +277,7 @@ def test_interactive_compatible_strategy_ddp_fork(monkeypatch):
pytest.param("deepspeed", DeepSpeedStrategy, marks=RunIf(deepspeed=True)),
),
)
@pytest.mark.parametrize("accelerator", ["mps", "auto", "gpu", None, MPSAccelerator()])
@pytest.mark.parametrize("accelerator", ["mps", "auto", "gpu", MPSAccelerator()])
def test_invalid_ddp_strategy_with_mps(accelerator, strategy, strategy_class):
with pytest.raises(ValueError, match="strategies from the DDP family are not supported"):
_Connector(accelerator=accelerator, strategy=strategy)


@ -692,6 +692,7 @@ def test_launch_and_cli_not_allowed():
fabric.launch()
@RunIf(mps=False)
@pytest.mark.parametrize("strategy", ("xla", "ddp_spawn"))
def test_launch_and_strategies_unsupported_combinations(strategy, xla_available):
fabric = Fabric(strategy=strategy)
@ -714,16 +715,36 @@ def test_module_sharding_context():
otherwise."""
fabric = Fabric()
fabric._strategy = MagicMock(spec=DDPStrategy, module_sharded_context=Mock())
with fabric.sharded_model():
with pytest.warns(DeprecationWarning, match="sharded_model"), fabric.sharded_model():
pass
fabric._strategy.module_sharded_context.assert_not_called()
fabric._strategy = MagicMock(spec=_Sharded)
with fabric.sharded_model():
with pytest.warns(DeprecationWarning, match="sharded_model"), fabric.sharded_model():
pass
fabric._strategy.module_sharded_context.assert_called_once()
def test_init_module_context(monkeypatch):
"""Test that the stratey returns the context manager for initializing the module."""
import lightning.fabric
fabric = Fabric(accelerator="cpu")
strategy = MagicMock(spec=Strategy, module_init_context=MagicMock(), root_device=torch.device("cuda", 0))
fabric._strategy = strategy
with fabric.init_module():
pass
strategy.module_init_context.assert_called_once()
strategy.reset_mock()
# Pretend we are using PyTorch < 2.0
monkeypatch.setattr(lightning.fabric.fabric, "_TORCH_GREATER_EQUAL_2_0", False)
with pytest.warns(PossibleUserWarning, match="can't place the model parameters on the device"): # noqa: SIM117
with fabric.init_module():
pass
strategy.module_init_context.assert_called_once()
def test_callbacks_input():
"""Test the various ways in which callbacks can be registered with Fabric."""
callback0 = Mock()