Fix parameter count in ModelSummary when parameters are DTensors (#20163)
parent 3de60f4b9f
commit 345450b0c3
@@ -14,10 +14,11 @@ import torch.nn.functional as F
 from lightning_utilities.core.imports import package_available
 from torch import Tensor
 from torch.utils.data import Dataset, DistributedSampler, Sampler
-from typing_extensions import Self, override
+from typing_extensions import Self, TypeGuard, override
 
 from lightning.fabric.utilities.cloud_io import _is_local_file_protocol
 from lightning.fabric.utilities.data import _num_cpus_available
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
 from lightning.fabric.utilities.rank_zero import rank_zero_info
 from lightning.fabric.utilities.types import _PATH, ReduceOp
 
@@ -30,6 +31,8 @@ else:
 
 
 if TYPE_CHECKING:
+    from torch.distributed._tensor import DTensor
+
     from lightning.fabric.plugins import ClusterEnvironment
     from lightning.fabric.strategies import Strategy
 
@@ -427,3 +430,11 @@ class _InfiniteBarrier:
         self.barrier()
         if self.group is not None:
             torch.distributed.destroy_process_group(self.group)
+
+
+def _is_dtensor(tensor: Tensor) -> TypeGuard["DTensor"]:
+    if _TORCH_GREATER_EQUAL_2_4:
+        from torch.distributed._tensor import DTensor
+
+        return isinstance(tensor, DTensor)
+    return False
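The new `_is_dtensor` guard lets callers branch on sharded parameters without unconditionally importing `DTensor`, and its `TypeGuard` return type narrows the argument for static checkers; the guard only returns True on PyTorch >= 2.4. A rough usage sketch (not part of the patch) that borrows the `Mock(spec=DTensor)` trick from the test added further down, so no process group is required:

from unittest.mock import Mock

import torch
from torch.distributed._tensor import DTensor

from lightning.fabric.utilities.distributed import _is_dtensor

# A Mock built with `spec=DTensor` passes `isinstance(..., DTensor)`, so the guard
# can be exercised without initializing torch.distributed.
maybe_sharded = Mock(spec=DTensor)

if _is_dtensor(maybe_sharded):
    # Thanks to the `TypeGuard["DTensor"]` annotation, type checkers treat
    # `maybe_sharded` as a DTensor inside this branch.
    print("parameter is sharded across devices")

print(_is_dtensor(torch.zeros(2, 2)))  # False: a plain tensor is not a DTensor
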
@@ -49,6 +49,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed `_LoggerConnector`'s `_ResultMetric` to move all registered keys to the device of the logged value if needed ([#19814](https://github.com/Lightning-AI/pytorch-lightning/issues/19814))
 
+- Fixed parameter counts in `ModelSummary` when model has distributed parameters (DTensor) ([#20163](https://github.com/Lightning-AI/pytorch-lightning/pull/20163))
 
 
 ## [2.3.0] - 2024-06-13
 
@@ -25,6 +25,7 @@ from torch import Tensor
 from torch.utils.hooks import RemovableHandle
 
 import lightning.pytorch as pl
+from lightning.fabric.utilities.distributed import _is_dtensor
 from lightning.pytorch.utilities.model_helpers import _ModuleMode
 from lightning.pytorch.utilities.rank_zero import WarningCache
 
@@ -135,7 +136,7 @@ class LayerSummary:
     @property
     def num_parameters(self) -> int:
         """Returns the number of parameters in this module."""
-        return sum(math.prod(p.shape) if not _is_lazy_weight_tensor(p) else 0 for p in self._module.parameters())
+        return sum(p.numel() if not _tensor_has_shape(p) else 0 for p in self._module.parameters())
 
     @property
     def training(self) -> bool:
@@ -264,13 +265,11 @@ class ModelSummary:
 
     @property
     def total_parameters(self) -> int:
-        return sum(p.numel() if not _is_lazy_weight_tensor(p) else 0 for p in self._model.parameters())
+        return sum(p.numel() if not _tensor_has_shape(p) else 0 for p in self._model.parameters())
 
     @property
     def trainable_parameters(self) -> int:
-        return sum(
-            p.numel() if not _is_lazy_weight_tensor(p) else 0 for p in self._model.parameters() if p.requires_grad
-        )
+        return sum(p.numel() if not _tensor_has_shape(p) else 0 for p in self._model.parameters() if p.requires_grad)
 
     @property
     def total_layer_params(self) -> int:
@@ -470,10 +469,11 @@ def get_human_readable_count(number: int) -> str:
     return f"{number:,.1f} {labels[index]}"
 
 
-def _is_lazy_weight_tensor(p: Tensor) -> bool:
+def _tensor_has_shape(p: Tensor) -> bool:
     from torch.nn.parameter import UninitializedParameter
 
-    if isinstance(p, UninitializedParameter):
+    # DTensor is a subtype of `UninitializedParameter`, but the shape is known
+    if isinstance(p, UninitializedParameter) and not _is_dtensor(p):
         warning_cache.warn(
             "The total number of parameters detected may be inaccurate because the model contains"
             " an instance of `UninitializedParameter`. To get an accurate number, set `self.example_input_array`"
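In short, layer and model summaries now count `p.numel()` for every parameter whose shape is already materialized, and the renamed `_tensor_has_shape` helper only filters out genuinely lazy `UninitializedParameter`s; DTensor parameters are no longer filtered, so sharded parameters are counted instead of showing up as zero. A minimal sketch of that counting rule outside Lightning; `count_params` and the toy model below are illustrative only, and the DTensor branch is omitted because building a real DTensor needs an initialized process group:

import torch
from torch import nn
from torch.nn.parameter import UninitializedParameter


def count_params(module: nn.Module) -> int:
    # Count only parameters with a materialized shape; a lazy
    # `UninitializedParameter` has no shape yet and contributes 0.
    return sum(
        0 if isinstance(p, UninitializedParameter) else p.numel()
        for p in module.parameters()
    )


model = nn.Sequential(nn.Linear(4, 8), nn.LazyLinear(2))
print(count_params(model))  # 40: only the regular Linear (4*8 weights + 8 biases) is counted
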
@@ -25,7 +25,7 @@ from lightning.pytorch.utilities.model_summary.model_summary import (
     NOT_APPLICABLE,
     LayerSummary,
     ModelSummary,
-    _is_lazy_weight_tensor,
+    _tensor_has_shape,
     get_human_readable_count,
 )
 
@@ -40,7 +40,7 @@ class DeepSpeedLayerSummary(LayerSummary):
     @override
     def num_parameters(self) -> int:
         """Returns the number of parameters in this module."""
-        return sum(deepspeed_param_size(p) if not _is_lazy_weight_tensor(p) else 0 for p in self._module.parameters())
+        return sum(deepspeed_param_size(p) if not _tensor_has_shape(p) else 0 for p in self._module.parameters())
 
     @property
     def average_shard_parameters(self) -> int:
@@ -49,7 +49,7 @@ class DeepSpeedLayerSummary(LayerSummary):
         def partitioned_size(p: Parameter) -> int:
             return p.partitioned_size() if RequirementCache("deepspeed<0.6.6") else p.partition_numel()
 
-        return sum(partitioned_size(p) if not _is_lazy_weight_tensor(p) else 0 for p in self._module.parameters())
+        return sum(partitioned_size(p) if not _tensor_has_shape(p) else 0 for p in self._module.parameters())
 
 
 class DeepSpeedSummary(ModelSummary):
@@ -71,13 +71,13 @@ class DeepSpeedSummary(ModelSummary):
     @property
     @override
     def total_parameters(self) -> int:
-        return sum(deepspeed_param_size(p) if not _is_lazy_weight_tensor(p) else 0 for p in self._model.parameters())
+        return sum(deepspeed_param_size(p) if not _tensor_has_shape(p) else 0 for p in self._model.parameters())
 
     @property
     @override
     def trainable_parameters(self) -> int:
         return sum(
-            deepspeed_param_size(p) if not _is_lazy_weight_tensor(p) else 0
+            deepspeed_param_size(p) if not _tensor_has_shape(p) else 0
             for p in self._model.parameters()
             if p.requires_grad
         )
 
@@ -3,7 +3,9 @@ import os
 from functools import partial
 from pathlib import Path
 from unittest import mock
+from unittest.mock import Mock
 
+import lightning.fabric
 import pytest
 import torch
 from lightning.fabric.accelerators import CPUAccelerator, CUDAAccelerator, MPSAccelerator
@@ -15,6 +17,7 @@ from lightning.fabric.utilities.distributed import (
     _gather_all_tensors,
     _InfiniteBarrier,
     _init_dist_connection,
+    _is_dtensor,
     _set_num_threads_if_needed,
     _suggested_max_num_threads,
     _sync_ddp,
@@ -234,3 +237,14 @@ def test_init_dist_connection_registers_destruction_handler(_, atexit_mock):
     atexit_mock.reset_mock()
     _init_dist_connection(LightningEnvironment(), "gloo")
     atexit_mock.register.assert_not_called()
+
+
+@RunIf(min_torch="2.4")
+def test_is_dtensor(monkeypatch):
+    from torch.distributed._tensor import DTensor
+
+    assert _is_dtensor(Mock(spec=DTensor))
+    assert not _is_dtensor(torch.zeros(2, 2))
+
+    monkeypatch.setattr(lightning.fabric.utilities.distributed, "_TORCH_GREATER_EQUAL_2_4", False)
+    assert not _is_dtensor(Mock(spec=DTensor))
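The test above sidesteps real distributed setup: a `Mock(spec=DTensor)` reports the spec class as its `__class__`, so the plain `isinstance` check inside `_is_dtensor` succeeds. A tiny standalone illustration of that mocking mechanism (the `Sharded` class here is just a stand-in, unrelated to Lightning):

from unittest.mock import Mock


class Sharded:
    """Stand-in class; any type works the same way with `spec`."""


fake = Mock(spec=Sharded)
print(isinstance(fake, Sharded))  # True: `spec` makes the mock pass isinstance checks
print(isinstance(fake, dict))     # False: unrelated types still fail
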
@@ -13,6 +13,7 @@
 # limitations under the License.
 from collections import OrderedDict
 from typing import Any
+from unittest import mock
 
 import pytest
 import torch
@@ -345,6 +346,18 @@ def test_lazy_model_summary():
     assert summary.trainable_parameters == 0
 
 
+@mock.patch("lightning.pytorch.utilities.model_summary.model_summary._is_dtensor", return_value=True)
+def test_dtensor_model_summary(_):
+    """Test that the model summary can work with layers that have DTensor parameters."""
+    # We mock the `_is_dtensor` to pretend parameters are DTensors, because testing with real DTensors
+    # would require setting up distributed
+    dtensor_model = UnorderedModel()
+    summary = ModelSummary(dtensor_model)
+    assert summary.total_layer_params > 0
+    assert summary.total_parameters > 0
+    assert summary.trainable_parameters > 0
+
+
 @pytest.mark.parametrize("max_depth", [-1, 0, 1, 3, 999])
 def test_max_depth_param(max_depth):
     """Test that only the modules up to the desired depth are shown."""