Improved model initialization API for Fabric (#17462)
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>

commit 4d17b5fe77
parent d48ec08d76
@@ -139,6 +139,23 @@ Make your code reproducible by calling this method at the beginning of your run.

This covers PyTorch, NumPy, and Python random number generators. In addition, Fabric takes care of properly initializing
the seed of data loader worker processes (can be turned off by passing ``workers=False``).
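For reference, a minimal sketch of how this is typically called at the start of a run (the seed value is arbitrary; ``workers=False`` would skip seeding the dataloader worker processes, per the behavior described above):

.. code-block:: python

    fabric = Fabric()
    # Seeds the PyTorch, NumPy, and Python random number generators.
    fabric.seed_everything(42)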
init_module
===========

Instantiating a ``nn.Module`` in PyTorch creates all parameters on CPU in float32 precision by default.
To speed up initialization, you can force PyTorch to create the model directly on the target device and with the desired precision without changing your model code.

.. code-block:: python

    fabric = Fabric(accelerator="cuda", precision="16-true")

    with fabric.init_module():
        # models created here will be on GPU and in float16
        model = MyModel()

This eliminates the waiting time to transfer the model parameters from the CPU to the device.
For strategies that handle large sharded models (FSDP, DeepSpeed), the :meth:`~lightning.fabric.fabric.Fabric.init_module` method will allocate the model parameters on the meta device first before sharding.
This makes it possible to work with models that are larger than the memory of a single device.
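For the sharded case described above, a hedged sketch (the strategy string and ``MyModel`` are placeholders; the meta-device behavior applies with PyTorch 2.0 or newer):

.. code-block:: python

    # With a sharding strategy such as FSDP, parameters created under
    # init_module() land on the meta device and are only materialized
    # and sharded once fabric.setup() runs.
    fabric = Fabric(accelerator="cuda", devices=2, strategy="fsdp", precision="16-true")
    fabric.launch()

    with fabric.init_module():
        model = MyModel()  # parameters on the meta device, not yet allocated

    model = fabric.setup(model)  # parameters get sharded and moved to the GPUs here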

autocast
========
@@ -30,6 +30,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added a warning when calling methods on `_FabricModule` that bypass the strategy-specific wrappers ([#17424](https://github.com/Lightning-AI/lightning/pull/17424))

- Added `Fabric.init_module()` context manager to instantiate large models efficiently ([#17462](https://github.com/Lightning-AI/lightning/pull/17462))

- Added `lightning.fabric.strategies.Strategy.module_init_context()` context manager to control the model instantiation ([#17462](https://github.com/Lightning-AI/lightning/pull/17462))
  * Creates the model parameters in the desired dtype (`torch.float32`, `torch.float64`, `torch.float16`, or `torch.bfloat16`) depending on the 'true' precision choice in `Fabric(precision='32-true'|'64-true'|'16-true'|'bf16-true')`
  * Initializes empty weights on the meta device for FSDP models and handles the Zero stage 3 initialization for DeepSpeed before sharding
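An illustrative sketch of the dtype behavior these entries describe, using only the APIs added in this commit (CPU is chosen here just to keep the example self-contained):

```python
import torch
from lightning.fabric import Fabric

# The 'true' precision flag selects the parameter dtype at creation time.
fabric = Fabric(accelerator="cpu", precision="64-true")
with fabric.init_module():
    model = torch.nn.Linear(2, 2)
assert model.weight.dtype == torch.float64
```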

- Run the DDP wrapper in a CUDA stream ([#17334](https://github.com/Lightning-AI/lightning/pull/17334))
@@ -22,7 +22,7 @@ import torch
import torch.nn as nn
from lightning_utilities.core.apply_func import apply_to_collection
from lightning_utilities.core.overrides import is_overridden
from lightning_utilities.core.rank_zero import rank_zero_warn
from lightning_utilities.core.rank_zero import rank_zero_deprecation, rank_zero_warn
from torch import Tensor
from torch.optim import Optimizer
from torch.utils.data import BatchSampler, DataLoader, DistributedSampler, RandomSampler, SequentialSampler
@@ -574,19 +574,30 @@ class Fabric:
    def sharded_model(self) -> Generator:
        """Shard the parameters of the model instantly when instantiating the layers.

        Use this context manager with strategies that support sharding the model parameters to save peak memory usage.

        Example::

            with self.sharded_model():
                model = MyModel()

        The context manager is strategy-agnostic and for the ones that don't do sharding, it is a no-op.
        .. deprecated:: This context manager is deprecated in favor of :meth:`init_module`, use it instead.
        """
        if isinstance(self._strategy, _Sharded):
            with self._strategy.module_sharded_context():
                yield
        else:
        rank_zero_deprecation("`Fabric.sharded_model()` is deprecated in favor of `Fabric.init_module()`.")
        with _old_sharded_model_context(self._strategy):
            yield

    @contextmanager
    def init_module(self) -> Generator:
        """Instantiate the model and its parameters under this context manager to reduce peak memory usage.

        The parameters get created on the device and with the right data type right away without wasting memory being
        allocated unnecessarily.

        Note:
            The automatic device placement under this context manager is only supported with PyTorch 2.0 and newer.
        """
        if not _TORCH_GREATER_EQUAL_2_0 and self.device.type != "cpu":
            rank_zero_warn(
                "`Fabric.init_module()` can't place the model parameters on the device directly with PyTorch < 2.0."
                " Parameters will remain on CPU until `Fabric.setup()` is called. Upgrade to PyTorch >= 2.0 to fully"
                " utilize the features in `init_module()`.",
                category=PossibleUserWarning,
            )
        with self._strategy.module_init_context():
            yield

    def save(self, path: Union[str, Path], state: Dict[str, Union[nn.Module, Optimizer, Any]]) -> None:
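Since this hunk deprecates `Fabric.sharded_model()` in favor of `Fabric.init_module()`, a minimal migration sketch for user code (`fabric` and `MyModel` are placeholders assumed to exist):

```python
# Before: deprecated; for non-sharded strategies it was effectively a no-op.
with fabric.sharded_model():
    model = MyModel()

# After: also places parameters on the target device and in the configured dtype.
with fabric.init_module():
    model = MyModel()
```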
@@ -763,9 +774,9 @@ class Fabric:
    def _run_with_setup(self, run_function: Callable, *args: Any, **kwargs: Any) -> Any:
        self._strategy.setup_environment()
        # apply sharded context to prevent OOM
        with self.sharded_model(), _replace_dunder_methods(DataLoader, "dataset"), _replace_dunder_methods(
            BatchSampler
        ):
        with _old_sharded_model_context(self._strategy), _replace_dunder_methods(
            DataLoader, "dataset"
        ), _replace_dunder_methods(BatchSampler):
            return run_function(*args, **kwargs)

    def _move_model_to_device(self, model: nn.Module, optimizers: List[Optimizer]) -> nn.Module:
@@ -864,3 +875,12 @@ class Fabric:

def _is_using_cli() -> bool:
    return bool(int(os.environ.get("LT_CLI_USED", "0")))


@contextmanager
def _old_sharded_model_context(strategy: Strategy) -> Generator:
    if isinstance(strategy, _Sharded):
        with strategy.module_sharded_context():
            yield
    else:
        yield
@@ -31,6 +31,17 @@ class DoublePrecision(Precision):
    def convert_module(self, module: Module) -> Module:
        return module.double()

    @contextmanager
    def module_init_context(self) -> Generator[None, None, None]:
        """A context manager to change the default tensor type when initializing the parameters in a module.

        See: :meth:`torch.set_default_tensor_type`
        """
        default_dtype = torch.get_default_dtype()
        torch.set_default_dtype(torch.float64)
        yield
        torch.set_default_dtype(default_dtype)

    @contextmanager
    def forward_context(self) -> Generator[None, None, None]:
        """A context manager to change the default tensor type.
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
from contextlib import contextmanager
from typing import Any, Dict, Generator, Literal, Optional, Union

from torch import Tensor
@@ -42,7 +42,15 @@ class Precision:
        """
        return module

    @contextlib.contextmanager
    @contextmanager
    def module_init_context(self) -> Generator[None, None, None]:
        """Instantiate the module parameters in the precision type this plugin handles.

        This is optional and depends on the precision limitations during optimization.
        """
        yield

    @contextmanager
    def forward_context(self) -> Generator[None, None, None]:
        """A contextmanager for managing model forward/training_step/evaluation_step/predict_step."""
        yield
@@ -338,6 +338,11 @@ class DeepSpeedStrategy(DDPStrategy, _Sharded):
        """
        raise NotImplementedError(self._err_msg_joint_setup_required())

    @contextmanager
    def module_init_context(self) -> Generator[None, None, None]:
        with super().module_init_context(), self.module_sharded_context():
            yield

    @contextmanager
    def module_sharded_context(self) -> Generator[None, None, None]:
        # Current limitation in Fabric: The config needs to be fully determined at the time of calling the context
@@ -13,7 +13,7 @@
# limitations under the License.
import functools
import os
from contextlib import _GeneratorContextManager, contextmanager
from contextlib import _GeneratorContextManager, contextmanager, nullcontext
from datetime import timedelta
from pathlib import Path
from typing import Any, Dict, Generator, List, Optional, Tuple, Type, TYPE_CHECKING, Union
@@ -243,6 +243,12 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
    def module_to_device(self, module: Module) -> None:
        pass

    @contextmanager
    def module_init_context(self) -> Generator[None, None, None]:
        device_context = torch.device("meta") if _TORCH_GREATER_EQUAL_2_0 else nullcontext()
        with device_context, self.precision.module_init_context(), self.module_sharded_context():
            yield

    @contextmanager
    def module_sharded_context(self) -> Generator:
        from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel
@@ -13,7 +13,7 @@
# limitations under the License.
import logging
from abc import ABC, abstractmethod
from contextlib import contextmanager
from contextlib import contextmanager, nullcontext
from typing import Any, Dict, Generator, List, Optional, Tuple, TypeVar, Union

import torch
@@ -28,6 +28,7 @@ from lightning.fabric.plugins.io.torch_io import TorchCheckpointIO
from lightning.fabric.plugins.precision import Precision
from lightning.fabric.strategies.launchers.launcher import _Launcher
from lightning.fabric.utilities.apply_func import move_data_to_device
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
from lightning.fabric.utilities.types import _PATH, _Stateful, Optimizable, ReduceOp

TBroadcast = TypeVar("TBroadcast")
@@ -111,6 +112,17 @@ class Strategy(ABC):
        """
        return dataloader

    @contextmanager
    def module_init_context(self) -> Generator:
        """A context manager wrapping the model instantiation.

        Here, the strategy can control how the parameters of the model get created (device, dtype) and/or apply other
        patches to the model.
        """
        device_context = self.root_device if _TORCH_GREATER_EQUAL_2_0 else nullcontext()
        with device_context, self.precision.module_init_context():
            yield

    def setup_module_and_optimizers(
        self, module: Module, optimizers: List[Optimizer]
    ) -> Tuple[Module, List[Optimizer]]:
@@ -17,6 +17,7 @@ import torch
import torch.nn as nn

from tests_fabric.helpers.models import BoringFabric
from tests_fabric.helpers.runif import RunIf


class BoringDoubleModule(nn.Module):

@@ -50,6 +51,7 @@ class DoublePrecisionBoringFabric(BoringFabric):
        assert model.layer.weight.grad.dtype == torch.float64


@RunIf(mps=False)  # MPS doesn't support float64
def test_double_precision():
    fabric = DoublePrecisionBoringFabric(devices=1, precision="64-true")
    fabric.run()
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from unittest import mock
from unittest.mock import MagicMock, Mock

@@ -18,8 +19,11 @@ import pytest
import torch
from torch.nn.parallel import DistributedDataParallel

from lightning.fabric.plugins import DoublePrecision, Precision
from lightning.fabric.plugins.environments import LightningEnvironment
from lightning.fabric.strategies import DDPStrategy
from lightning.fabric.strategies.ddp import _DDPBackwardSyncControl
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
from tests_fabric.helpers.runif import RunIf
from tests_fabric.strategies.test_single_device import _MyFabricGradNorm, _MyFabricGradVal
@@ -122,3 +126,27 @@ def test_ddp_grad_clipping(clip_type, accelerator, precision):
    clipping_test_cls = _MyFabricGradNorm if clip_type == "norm" else _MyFabricGradVal
    fabric = clipping_test_cls(accelerator=accelerator, devices=2, precision=precision, strategy="ddp")
    fabric.run()


@RunIf(min_cuda_gpus=2)
@pytest.mark.parametrize(
    "precision,expected_dtype",
    [
        (Precision(), torch.float32),
        (DoublePrecision(), torch.float64),
    ],
)
@mock.patch.dict(os.environ, {"LOCAL_RANK": "1"})
def test_module_init_context(precision, expected_dtype):
    """Test that the module under the init-context gets moved to the right device and dtype."""
    parallel_devices = [torch.device("cuda", 0), torch.device("cuda", 1)]
    expected_device = parallel_devices[1] if _TORCH_GREATER_EQUAL_2_0 else torch.device("cpu")

    strategy = DDPStrategy(
        parallel_devices=parallel_devices, precision=precision, cluster_environment=LightningEnvironment()
    )
    assert strategy.local_rank == 1
    with strategy.module_init_context():
        module = torch.nn.Linear(2, 2)
    assert module.weight.device == module.bias.device == expected_device
    assert module.weight.dtype == module.bias.dtype == expected_dtype
@@ -366,7 +366,7 @@ def test_deepspeed_save_load_checkpoint_zero_3(stage, tmp_path):

    checkpoint_path = fabric.broadcast(tmp_path / "deepspeed-checkpoint")

    with fabric.sharded_model():
    with fabric.init_module():
        model = BoringModel()

    optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)

@@ -390,7 +390,7 @@ def test_deepspeed_save_load_checkpoint_zero_3(stage, tmp_path):
    # re-init all objects and resume
    fabric = Fabric(accelerator="cuda", devices=2, strategy=DeepSpeedStrategy(stage=stage), precision="bf16")
    fabric.launch()
    with fabric.sharded_model():
    with fabric.init_module():
        model = BoringModel()

    optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)
@@ -191,3 +191,37 @@ def test_compile(compile_after_setup):

    for _ in range(3):
        model(torch.rand(2, 32, device=fabric.device)).sum().backward()


@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
@pytest.mark.parametrize(
    "precision,expected_dtype",
    [
        ("32-true", torch.float32),
        ("64-true", torch.float64),
    ],
)
def test_module_init_context(precision, expected_dtype):
    """Test that the module under the init-context gets moved to the right device and dtype."""
    fabric = Fabric(
        accelerator="cuda",
        devices=2,
        strategy=FSDPStrategy(auto_wrap_policy=always_wrap_policy),
        precision=precision,
    )
    fabric.launch()

    with fabric.init_module():
        model = torch.nn.Linear(100, 100, bias=False)

    # The model is on the meta device until `.setup()`
    expected_device = torch.device("meta") if _TORCH_GREATER_EQUAL_2_0 else torch.device("cpu")
    assert model.weight.device == expected_device
    assert model.weight.dtype == expected_dtype

    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    model, optimizer = fabric.setup(model, optimizer)

    # Parameters get sharded in `.setup()` and moved to the target device
    assert model.weight.device == torch.device("cuda", fabric.local_rank)
    assert model.weight.dtype == expected_dtype
@@ -16,7 +16,9 @@ from unittest.mock import Mock
import pytest
import torch

from lightning.fabric.plugins import DoublePrecision, Precision
from lightning.fabric.strategies import SingleDeviceStrategy
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
from lightning.fabric.wrappers import _FabricModule, _FabricOptimizer
from tests_fabric.helpers.models import BoringFabric
from tests_fabric.helpers.runif import RunIf
@@ -147,3 +149,30 @@ def test_single_device_grad_clipping(clip_type, precision):
    clipping_test_cls = _MyFabricGradNorm if clip_type == "norm" else _MyFabricGradVal
    fabric = clipping_test_cls(accelerator="auto", devices=1, precision=precision)
    fabric.run()


@pytest.mark.parametrize(
    "device",
    [
        "cpu",
        pytest.param("cuda:0", marks=RunIf(min_cuda_gpus=1)),
        pytest.param("mps:0", marks=RunIf(mps=True)),
    ],
)
@pytest.mark.parametrize(
    "precision,dtype",
    [
        (Precision(), torch.float32),
        pytest.param(DoublePrecision(), torch.float64, marks=RunIf(mps=False)),
    ],
)
def test_module_init_context(device, precision, dtype):
    """Test that the module under the init-context gets moved to the right device and dtype."""
    device = torch.device(device)
    strategy = SingleDeviceStrategy(device=device, precision=precision)
    with strategy.module_init_context():
        module = torch.nn.Linear(2, 2)

    expected_device = device if _TORCH_GREATER_EQUAL_2_0 else torch.device("cpu")
    assert module.weight.device == module.bias.device == expected_device
    assert module.weight.dtype == module.bias.dtype == dtype
@@ -277,7 +277,7 @@ def test_interactive_compatible_strategy_ddp_fork(monkeypatch):
        pytest.param("deepspeed", DeepSpeedStrategy, marks=RunIf(deepspeed=True)),
    ),
)
@pytest.mark.parametrize("accelerator", ["mps", "auto", "gpu", None, MPSAccelerator()])
@pytest.mark.parametrize("accelerator", ["mps", "auto", "gpu", MPSAccelerator()])
def test_invalid_ddp_strategy_with_mps(accelerator, strategy, strategy_class):
    with pytest.raises(ValueError, match="strategies from the DDP family are not supported"):
        _Connector(accelerator=accelerator, strategy=strategy)
@@ -692,6 +692,7 @@ def test_launch_and_cli_not_allowed():
    fabric.launch()


@RunIf(mps=False)
@pytest.mark.parametrize("strategy", ("xla", "ddp_spawn"))
def test_launch_and_strategies_unsupported_combinations(strategy, xla_available):
    fabric = Fabric(strategy=strategy)
@@ -714,16 +715,36 @@ def test_module_sharding_context():
    otherwise."""
    fabric = Fabric()
    fabric._strategy = MagicMock(spec=DDPStrategy, module_sharded_context=Mock())
    with fabric.sharded_model():
    with pytest.warns(DeprecationWarning, match="sharded_model"), fabric.sharded_model():
        pass
    fabric._strategy.module_sharded_context.assert_not_called()

    fabric._strategy = MagicMock(spec=_Sharded)
    with fabric.sharded_model():
    with pytest.warns(DeprecationWarning, match="sharded_model"), fabric.sharded_model():
        pass
    fabric._strategy.module_sharded_context.assert_called_once()


def test_init_module_context(monkeypatch):
    """Test that the strategy returns the context manager for initializing the module."""
    import lightning.fabric

    fabric = Fabric(accelerator="cpu")
    strategy = MagicMock(spec=Strategy, module_init_context=MagicMock(), root_device=torch.device("cuda", 0))
    fabric._strategy = strategy
    with fabric.init_module():
        pass
    strategy.module_init_context.assert_called_once()
    strategy.reset_mock()

    # Pretend we are using PyTorch < 2.0
    monkeypatch.setattr(lightning.fabric.fabric, "_TORCH_GREATER_EQUAL_2_0", False)
    with pytest.warns(PossibleUserWarning, match="can't place the model parameters on the device"):  # noqa: SIM117
        with fabric.init_module():
            pass
    strategy.module_init_context.assert_called_once()


def test_callbacks_input():
    """Test the various ways in which callbacks can be registered with Fabric."""
    callback0 = Mock()