(3/n) Support 2D Parallelism - Efficient loading of full-state checkpoints (#19870)
* memory-optimized loading of full checkpoints into dist model
* simplify
* handle buffers
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* handle strict loading, buffers, and add test
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* chlog

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 9455871c93
commit cd8acc26c3
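For context before the diff: the loading path added in #19870 is exercised from user code roughly as in the sketch below. This is a hedged illustration only, assuming 2 CUDA devices, PyTorch >= 2.4, and the `parallelize_fn` hook introduced in #19846; the checkpoint path, the toy model, and the no-op `parallelize` function are placeholders, not code from this PR.

```python
# Hedged usage sketch (not part of this diff): loading a single-file, full-state
# checkpoint into a model set up with ModelParallelStrategy.
import torch.nn as nn
from lightning.fabric import Fabric
from lightning.fabric.strategies import ModelParallelStrategy


def parallelize(model: nn.Module, device_mesh) -> nn.Module:
    # A real setup would apply tensor/FSDP2 parallelism here using device_mesh.
    return model


fabric = Fabric(accelerator="cuda", devices=2, strategy=ModelParallelStrategy(parallelize_fn=parallelize))
fabric.launch()
model = fabric.setup(nn.Linear(4, 4))

state = {"model": model}
# "full_checkpoint.pt" is a placeholder for a non-distributed, single-file checkpoint.
# With this PR it is mmap-loaded on CPU and broadcast from rank 0 instead of being
# fully materialized on every rank.
fabric.load("full_checkpoint.pt", state)
```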
@@ -15,7 +15,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added support for PyTorch 2.3 ([#19708](https://github.com/Lightning-AI/pytorch-lightning/pull/19708))
 
-- Added `ModelParallelStrategy` to support 2D parallelism ([#19846](https://github.com/Lightning-AI/pytorch-lightning/pull/19846), [#19852](https://github.com/Lightning-AI/pytorch-lightning/pull/19852))
+- Added `ModelParallelStrategy` to support 2D parallelism ([#19846](https://github.com/Lightning-AI/pytorch-lightning/pull/19846), [#19852](https://github.com/Lightning-AI/pytorch-lightning/pull/19852), [#19870](https://github.com/Lightning-AI/pytorch-lightning/pull/19870))
 
 
 ### Changed
@@ -11,11 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import itertools
 import shutil
 from contextlib import ExitStack
 from datetime import timedelta
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, ContextManager, Dict, Literal, Optional, TypeVar, Union
+from typing import TYPE_CHECKING, Any, Callable, ContextManager, Dict, Generator, Literal, Optional, TypeVar, Union
 
 import torch
 from lightning_utilities.core.rank_zero import rank_zero_only as utils_rank_zero_only
@@ -429,7 +430,6 @@ def _load_checkpoint(
         StateDictOptions,
         get_model_state_dict,
         get_optimizer_state_dict,
-        set_model_state_dict,
         set_optimizer_state_dict,
     )
 
@@ -484,13 +484,8 @@ def _load_checkpoint(
     if not _TORCH_GREATER_EQUAL_2_4:
         raise ImportError("Loading a non-distributed checkpoint into a distributed model requires PyTorch >= 2.4.")
 
-    state_dict_options = StateDictOptions(
-        broadcast_from_rank0=True,  # type: ignore[call-arg]
-        full_state_dict=True,
-        strict=strict,
-    )
     checkpoint = torch.load(path, mmap=True, map_location="cpu")
-    set_model_state_dict(module, checkpoint.pop(module_key), options=state_dict_options)
+    _load_raw_module_state(checkpoint.pop(module_key), module, strict=strict)
 
     requested_metadata_keys = state.keys() - modules.keys() - optimizers.keys()
     _validate_keys_for_strict_loading(requested_metadata_keys, checkpoint.keys(), strict=strict)
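The memory saving in this hunk comes from two pieces: `torch.load(..., mmap=True)` maps tensor storages lazily instead of reading the whole file into RAM, and `StateDictOptions(broadcast_from_rank0=True)` lets non-zero ranks receive only the values they need. A minimal sketch of that pattern follows; it assumes PyTorch >= 2.4, an already-initialized process group, and an illustrative "model" checkpoint key rather than Lightning's actual layout.

```python
# Minimal sketch of the loading pattern used by _load_checkpoint above.
# Assumes PyTorch >= 2.4 and torch.distributed already initialized; the
# "model" checkpoint key is illustrative, not Lightning's actual layout.
import torch
from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict


def load_full_checkpoint(model: torch.nn.Module, path: str) -> None:
    # mmap=True avoids materializing the entire file in memory;
    # map_location="cpu" keeps the mapped copy off the GPU.
    checkpoint = torch.load(path, mmap=True, map_location="cpu")
    options = StateDictOptions(full_state_dict=True, broadcast_from_rank0=True)
    # Rank 0 supplies the full values; the other ranks receive their shards
    # via broadcast and never hold the complete state dict on device.
    set_model_state_dict(model, checkpoint["model"], options=options)
```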
@@ -525,7 +520,9 @@ def _load_raw_module_state_from_path(path: Path, module: Module, world_size: int
     _load_raw_module_state(state_dict=state_dict, module=module, world_size=world_size, strict=strict)
 
 
-def _load_raw_module_state(state_dict: Dict[str, Any], module: Module, world_size: int, strict: bool = True) -> None:
+def _load_raw_module_state(
+    state_dict: Dict[str, Any], module: Module, world_size: int = 1, strict: bool = True
+) -> None:
     """Loads the state dict into the module by gathering all weights first and then writing back to each shard."""
     from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 
@@ -535,11 +532,39 @@ def _load_raw_module_state(state_dict: Dict[str, Any], module: Module, world_size: int, strict: bool = True) -> None:
 
         from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict
 
-        state_dict_options = StateDictOptions(broadcast_from_rank0=True, full_state_dict=True)  # type: ignore[call-arg]
-        set_model_state_dict(module, state_dict, options=state_dict_options)
+        state_dict_options = StateDictOptions(
+            broadcast_from_rank0=True,  # type: ignore[call-arg]
+            full_state_dict=True,
+            strict=strict,  # gets ignored at the moment
+        )
+
+        for submodule_name, submodule in module.named_modules():
+            for param_name, _ in _named_parameters_and_buffers_to_load(submodule):
+                full_param_name = f"{submodule_name}{'.' if submodule_name else ''}{param_name}"
+                if full_param_name not in state_dict:
+                    # Note: PyTorch does not currently respect the `strict` setting in state_dict_options!
+                    if not strict:
+                        continue
+                    raise KeyError(
+                        f"The model contains a key '{full_param_name}' that does not exist in the loaded checkpoint."
+                        " To disable strict loading, set `strict=False`."
+                    )
+                local_state_dict = {param_name: state_dict[full_param_name]}
+                set_model_state_dict(submodule, local_state_dict, options=state_dict_options)
 
     elif isinstance(module, FSDP):
         with _get_full_state_dict_context(module, world_size=world_size, rank0_only=False):
             module.load_state_dict(state_dict, strict=strict)
     else:
         module.load_state_dict(state_dict, strict=strict)
+
+
+def _named_parameters_and_buffers_to_load(module: Module) -> Generator:
+    """Returns parameters and buffers, with non-persistent buffers excluded."""
+    for param_name, param in itertools.chain(
+        module.named_buffers(recurse=False),
+        module.named_parameters(recurse=False),
+    ):
+        if param_name in module._non_persistent_buffers_set:
+            continue
+        yield param_name, param
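The new `_named_parameters_and_buffers_to_load` helper skips non-persistent buffers because they never appear in a state dict, so treating them as loadable keys would trigger spurious strict-loading errors. A small, self-contained illustration of that PyTorch behavior (the `Demo` module is made up for this example):

```python
# Illustration only: non-persistent buffers are excluded from state dicts,
# which is why the helper above filters them via _non_persistent_buffers_set.
import torch
import torch.nn as nn


class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(2, 2)
        self.register_buffer("running_stat", torch.zeros(2), persistent=True)
        self.register_buffer("scratch", torch.zeros(2), persistent=False)


demo = Demo()
print("running_stat" in demo.state_dict())  # True
print("scratch" in demo.state_dict())       # False
print(demo._non_persistent_buffers_set)     # {'scratch'}
```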
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+from copy import deepcopy
 from pathlib import Path
 from unittest import mock
 
@@ -20,7 +21,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from lightning.fabric import Fabric
-from lightning.fabric.strategies import ModelParallelStrategy
+from lightning.fabric.strategies.model_parallel import ModelParallelStrategy, _load_raw_module_state
 from lightning.fabric.utilities.load import _load_distributed_checkpoint
 from torch.utils.data import DataLoader, DistributedSampler
 
@@ -675,3 +676,46 @@ def test_save_sharded_and_consolidate_and_load(tmp_path):
 
     state = {"model": model, "steps": 1}
     fabric.load(checkpoint_path_full, state)
+
+
+@RunIf(min_torch="2.4", min_cuda_gpus=2, standalone=True)
+def test_load_raw_module_state():
+    from torch.distributed.device_mesh import init_device_mesh
+    from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module
+
+    class CustomModel(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.parameter = nn.Parameter(torch.rand(2, 2))
+            self.layer1 = nn.Linear(4, 4)
+            self.layer2 = nn.Linear(4, 4)
+            self.register_buffer("persistent_buffer", torch.rand(2), persistent=True)
+            self.register_buffer("non_persistent_buffer", torch.rand(2), persistent=False)
+
+    fabric = Fabric(accelerator="cuda", devices=2)
+    fabric.launch()
+    fabric.seed_everything(0)
+
+    with fabric.init_module():
+        model = CustomModel()
+
+    state_dict = deepcopy(model.state_dict())
+
+    with fabric.init_module():
+        model = CustomModel()
+
+    device_mesh = init_device_mesh("cuda", mesh_shape=(2,), mesh_dim_names=("tp",))
+    plan = {"layer1": ColwiseParallel()}
+    parallelize_module(model, device_mesh, plan)
+    _load_raw_module_state(state_dict, model, strict=True)
+
+    assert torch.equal(model.parameter, state_dict["parameter"])
+    assert torch.equal(model.layer1.weight.full_tensor(), state_dict["layer1.weight"])
+    assert torch.equal(model.layer2.weight, state_dict["layer2.weight"])
+    assert torch.equal(model.persistent_buffer, state_dict["persistent_buffer"])
+
+    state_dict.pop("parameter")
+    with pytest.raises(KeyError, match="The model contains a key 'parameter' that does not exist"):
+        _load_raw_module_state(state_dict, model, strict=True)
+
+    _load_raw_module_state(state_dict, model, strict=False)
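The assertion on `model.layer1.weight.full_tensor()` works because `ColwiseParallel` turns that weight into a DTensor sharded across the `tp` mesh, and `DTensor.full_tensor()` all-gathers the shards back into a regular tensor. Below is a single-process sketch of that behavior; it assumes PyTorch >= 2.5 (where `torch.distributed.tensor` is a public module) and uses a world size of 1 on CPU/gloo purely so it runs without multiple GPUs.

```python
# Single-process illustration of DTensor.full_tensor(): shard a tensor over a
# 1-element mesh, then gather it back and compare with the original.
import os

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)

mesh = init_device_mesh("cpu", (1,))
full = torch.arange(8.0).reshape(4, 2)
sharded = distribute_tensor(full, mesh, [Shard(0)])  # DTensor, sharded on dim 0
assert torch.equal(sharded.full_tensor(), full)      # shards gathered back into a plain tensor

dist.destroy_process_group()
```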