Fix parameter count in ModelSummary when parameters are DTensors (#20163)

awaelchli 2024-08-05 16:57:31 +02:00 committed by GitHub
parent 3de60f4b9f
commit 345450b0c3
6 changed files with 52 additions and 13 deletions

View File

@@ -14,10 +14,11 @@ import torch.nn.functional as F
from lightning_utilities.core.imports import package_available
from torch import Tensor
from torch.utils.data import Dataset, DistributedSampler, Sampler
from typing_extensions import Self, override
from typing_extensions import Self, TypeGuard, override
from lightning.fabric.utilities.cloud_io import _is_local_file_protocol
from lightning.fabric.utilities.data import _num_cpus_available
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
from lightning.fabric.utilities.rank_zero import rank_zero_info
from lightning.fabric.utilities.types import _PATH, ReduceOp
@@ -30,6 +31,8 @@ else:
if TYPE_CHECKING:
from torch.distributed._tensor import DTensor
from lightning.fabric.plugins import ClusterEnvironment
from lightning.fabric.strategies import Strategy
@@ -427,3 +430,11 @@ class _InfiniteBarrier:
self.barrier()
if self.group is not None:
torch.distributed.destroy_process_group(self.group)
def _is_dtensor(tensor: Tensor) -> TypeGuard["DTensor"]:
    if _TORCH_GREATER_EQUAL_2_4:
        from torch.distributed._tensor import DTensor

        return isinstance(tensor, DTensor)
    return False
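
The `TypeGuard["DTensor"]` annotation means a truthy result narrows the argument's static type for callers. A minimal sketch of that narrowing, not part of the commit, importing the helper the same way the model-summary and test diffs below do:

    from torch import Tensor
    from lightning.fabric.utilities.distributed import _is_dtensor

    def describe(t: Tensor) -> str:
        if _is_dtensor(t):
            # Within this branch, type checkers treat `t` as a DTensor, so
            # DTensor-only attributes such as `placements` need no cast.
            return f"DTensor with placements {t.placements}"
        return "regular tensor"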

View File

@@ -49,6 +49,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `_LoggerConnector`'s `_ResultMetric` to move all registered keys to the device of the logged value if needed ([#19814](https://github.com/Lightning-AI/pytorch-lightning/issues/19814))
- Fixed parameter counts in `ModelSummary` when model has distributed parameters (DTensor) ([#20163](https://github.com/Lightning-AI/pytorch-lightning/pull/20163))
## [2.3.0] - 2024-06-13

View File

@@ -25,6 +25,7 @@ from torch import Tensor
from torch.utils.hooks import RemovableHandle
import lightning.pytorch as pl
from lightning.fabric.utilities.distributed import _is_dtensor
from lightning.pytorch.utilities.model_helpers import _ModuleMode
from lightning.pytorch.utilities.rank_zero import WarningCache
@@ -135,7 +136,7 @@ class LayerSummary:
@property
def num_parameters(self) -> int:
"""Returns the number of parameters in this module."""
return sum(math.prod(p.shape) if not _is_lazy_weight_tensor(p) else 0 for p in self._module.parameters())
return sum(p.numel() if not _tensor_has_shape(p) else 0 for p in self._module.parameters())
@property
def training(self) -> bool:
@@ -264,13 +265,11 @@ class ModelSummary:
@property
def total_parameters(self) -> int:
return sum(p.numel() if not _is_lazy_weight_tensor(p) else 0 for p in self._model.parameters())
return sum(p.numel() if not _tensor_has_shape(p) else 0 for p in self._model.parameters())
@property
def trainable_parameters(self) -> int:
return sum(
p.numel() if not _is_lazy_weight_tensor(p) else 0 for p in self._model.parameters() if p.requires_grad
)
return sum(p.numel() if not _tensor_has_shape(p) else 0 for p in self._model.parameters() if p.requires_grad)
@property
def total_layer_params(self) -> int:
@@ -470,10 +469,11 @@ def get_human_readable_count(number: int) -> str:
return f"{number:,.1f} {labels[index]}"
def _is_lazy_weight_tensor(p: Tensor) -> bool:
def _tensor_has_shape(p: Tensor) -> bool:
from torch.nn.parameter import UninitializedParameter
if isinstance(p, UninitializedParameter):
# DTensor is a subtype of `UninitializedParameter`, but the shape is known
if isinstance(p, UninitializedParameter) and not _is_dtensor(p):
warning_cache.warn(
"The total number of parameters detected may be inaccurate because the model contains"
" an instance of `UninitializedParameter`. To get an accurate number, set `self.example_input_array`"

View File

@@ -25,7 +25,7 @@ from lightning.pytorch.utilities.model_summary.model_summary import (
NOT_APPLICABLE,
LayerSummary,
ModelSummary,
_is_lazy_weight_tensor,
_tensor_has_shape,
get_human_readable_count,
)
@@ -40,7 +40,7 @@ class DeepSpeedLayerSummary(LayerSummary):
@override
def num_parameters(self) -> int:
"""Returns the number of parameters in this module."""
return sum(deepspeed_param_size(p) if not _is_lazy_weight_tensor(p) else 0 for p in self._module.parameters())
return sum(deepspeed_param_size(p) if not _tensor_has_shape(p) else 0 for p in self._module.parameters())
@property
def average_shard_parameters(self) -> int:
@@ -49,7 +49,7 @@ class DeepSpeedLayerSummary(LayerSummary):
def partitioned_size(p: Parameter) -> int:
return p.partitioned_size() if RequirementCache("deepspeed<0.6.6") else p.partition_numel()
return sum(partitioned_size(p) if not _is_lazy_weight_tensor(p) else 0 for p in self._module.parameters())
return sum(partitioned_size(p) if not _tensor_has_shape(p) else 0 for p in self._module.parameters())
class DeepSpeedSummary(ModelSummary):
@@ -71,13 +71,13 @@ class DeepSpeedSummary(ModelSummary):
@property
@override
def total_parameters(self) -> int:
return sum(deepspeed_param_size(p) if not _is_lazy_weight_tensor(p) else 0 for p in self._model.parameters())
return sum(deepspeed_param_size(p) if not _tensor_has_shape(p) else 0 for p in self._model.parameters())
@property
@override
def trainable_parameters(self) -> int:
return sum(
deepspeed_param_size(p) if not _is_lazy_weight_tensor(p) else 0
deepspeed_param_size(p) if not _tensor_has_shape(p) else 0
for p in self._model.parameters()
if p.requires_grad
)

View File

@@ -3,7 +3,9 @@ import os
from functools import partial
from pathlib import Path
from unittest import mock
from unittest.mock import Mock
import lightning.fabric
import pytest
import torch
from lightning.fabric.accelerators import CPUAccelerator, CUDAAccelerator, MPSAccelerator
@@ -15,6 +17,7 @@ from lightning.fabric.utilities.distributed import (
_gather_all_tensors,
_InfiniteBarrier,
_init_dist_connection,
_is_dtensor,
_set_num_threads_if_needed,
_suggested_max_num_threads,
_sync_ddp,
@@ -234,3 +237,14 @@ def test_init_dist_connection_registers_destruction_handler(_, atexit_mock):
atexit_mock.reset_mock()
_init_dist_connection(LightningEnvironment(), "gloo")
atexit_mock.register.assert_not_called()
@RunIf(min_torch="2.4")
def test_is_dtensor(monkeypatch):
    from torch.distributed._tensor import DTensor

    assert _is_dtensor(Mock(spec=DTensor))
    assert not _is_dtensor(torch.zeros(2, 2))

    monkeypatch.setattr(lightning.fabric.utilities.distributed, "_TORCH_GREATER_EQUAL_2_4", False)
    assert not _is_dtensor(Mock(spec=DTensor))
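
The test stays single-process because `Mock(spec=DTensor)` sets the mock's `__class__` to `DTensor`, so the `isinstance` check inside `_is_dtensor` passes without a process group or device mesh. The same pattern in isolation, as a sketch assuming torch >= 2.4:

    from unittest.mock import Mock
    from torch.distributed._tensor import DTensor

    fake = Mock(spec=DTensor)
    assert isinstance(fake, DTensor)         # spec'd mocks satisfy isinstance checks
    assert not isinstance(object(), DTensor)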

View File

@@ -13,6 +13,7 @@
# limitations under the License.
from collections import OrderedDict
from typing import Any
from unittest import mock
import pytest
import torch
@@ -345,6 +346,18 @@ def test_lazy_model_summary():
assert summary.trainable_parameters == 0
@mock.patch("lightning.pytorch.utilities.model_summary.model_summary._is_dtensor", return_value=True)
def test_dtensor_model_summary(_):
    """Test that the model summary can work with layers that have DTensor parameters."""
    # We mock the `_is_dtensor` to pretend parameters are DTensors, because testing with real DTensors
    # would require setting up distributed
    dtensor_model = UnorderedModel()
    summary = ModelSummary(dtensor_model)
    assert summary.total_layer_params > 0
    assert summary.total_parameters > 0
    assert summary.trainable_parameters > 0
@pytest.mark.parametrize("max_depth", [-1, 0, 1, 3, 999])
def test_max_depth_param(max_depth):
"""Test that only the modules up to the desired depth are shown."""