From 345450b0c3a3828c675488c6d41dd8bddb3dd008 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Mon, 5 Aug 2024 16:57:31 +0200 Subject: [PATCH] Fix parameter count in ModelSummary when parameters are DTensors (#20163) --- src/lightning/fabric/utilities/distributed.py | 13 ++++++++++++- src/lightning/pytorch/CHANGELOG.md | 1 + .../utilities/model_summary/model_summary.py | 14 +++++++------- .../model_summary/model_summary_deepspeed.py | 10 +++++----- tests/tests_fabric/utilities/test_distributed.py | 14 ++++++++++++++ .../tests_pytorch/utilities/test_model_summary.py | 13 +++++++++++++ 6 files changed, 52 insertions(+), 13 deletions(-) diff --git a/src/lightning/fabric/utilities/distributed.py b/src/lightning/fabric/utilities/distributed.py index 75b2f7c580..0e6c52dfb0 100644 --- a/src/lightning/fabric/utilities/distributed.py +++ b/src/lightning/fabric/utilities/distributed.py @@ -14,10 +14,11 @@ import torch.nn.functional as F from lightning_utilities.core.imports import package_available from torch import Tensor from torch.utils.data import Dataset, DistributedSampler, Sampler -from typing_extensions import Self, override +from typing_extensions import Self, TypeGuard, override from lightning.fabric.utilities.cloud_io import _is_local_file_protocol from lightning.fabric.utilities.data import _num_cpus_available +from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4 from lightning.fabric.utilities.rank_zero import rank_zero_info from lightning.fabric.utilities.types import _PATH, ReduceOp @@ -30,6 +31,8 @@ else: if TYPE_CHECKING: + from torch.distributed._tensor import DTensor + from lightning.fabric.plugins import ClusterEnvironment from lightning.fabric.strategies import Strategy @@ -427,3 +430,11 @@ class _InfiniteBarrier: self.barrier() if self.group is not None: torch.distributed.destroy_process_group(self.group) + + +def _is_dtensor(tensor: Tensor) -> TypeGuard["DTensor"]: + if _TORCH_GREATER_EQUAL_2_4: + from torch.distributed._tensor import DTensor + + return isinstance(tensor, DTensor) + return False diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 4d8eebf134..eba67ebffc 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -49,6 +49,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed `_LoggerConnector`'s `_ResultMetric` to move all registered keys to the device of the logged value if needed ([#19814](https://github.com/Lightning-AI/pytorch-lightning/issues/19814)) +- Fixed parameter counts in `ModelSummary` when model has distributed parameters (DTensor) ([#20163](https://github.com/Lightning-AI/pytorch-lightning/pull/20163)) ## [2.3.0] - 2024-06-13 diff --git a/src/lightning/pytorch/utilities/model_summary/model_summary.py b/src/lightning/pytorch/utilities/model_summary/model_summary.py index 0f48bee191..c40dc94568 100644 --- a/src/lightning/pytorch/utilities/model_summary/model_summary.py +++ b/src/lightning/pytorch/utilities/model_summary/model_summary.py @@ -25,6 +25,7 @@ from torch import Tensor from torch.utils.hooks import RemovableHandle import lightning.pytorch as pl +from lightning.fabric.utilities.distributed import _is_dtensor from lightning.pytorch.utilities.model_helpers import _ModuleMode from lightning.pytorch.utilities.rank_zero import WarningCache @@ -135,7 +136,7 @@ class LayerSummary: @property def num_parameters(self) -> int: """Returns the number of parameters in this module.""" - return sum(math.prod(p.shape) if not _is_lazy_weight_tensor(p) else 0 for p in self._module.parameters()) + return sum(p.numel() if not _tensor_has_shape(p) else 0 for p in self._module.parameters()) @property def training(self) -> bool: @@ -264,13 +265,11 @@ class ModelSummary: @property def total_parameters(self) -> int: - return sum(p.numel() if not _is_lazy_weight_tensor(p) else 0 for p in self._model.parameters()) + return sum(p.numel() if not _tensor_has_shape(p) else 0 for p in self._model.parameters()) @property def trainable_parameters(self) -> int: - return sum( - p.numel() if not _is_lazy_weight_tensor(p) else 0 for p in self._model.parameters() if p.requires_grad - ) + return sum(p.numel() if not _tensor_has_shape(p) else 0 for p in self._model.parameters() if p.requires_grad) @property def total_layer_params(self) -> int: @@ -470,10 +469,11 @@ def get_human_readable_count(number: int) -> str: return f"{number:,.1f} {labels[index]}" -def _is_lazy_weight_tensor(p: Tensor) -> bool: +def _tensor_has_shape(p: Tensor) -> bool: from torch.nn.parameter import UninitializedParameter - if isinstance(p, UninitializedParameter): + # DTensor is a subtype of `UninitializedParameter`, but the shape is known + if isinstance(p, UninitializedParameter) and not _is_dtensor(p): warning_cache.warn( "The total number of parameters detected may be inaccurate because the model contains" " an instance of `UninitializedParameter`. To get an accurate number, set `self.example_input_array`" diff --git a/src/lightning/pytorch/utilities/model_summary/model_summary_deepspeed.py b/src/lightning/pytorch/utilities/model_summary/model_summary_deepspeed.py index c3c9cfe982..57d9ae5024 100644 --- a/src/lightning/pytorch/utilities/model_summary/model_summary_deepspeed.py +++ b/src/lightning/pytorch/utilities/model_summary/model_summary_deepspeed.py @@ -25,7 +25,7 @@ from lightning.pytorch.utilities.model_summary.model_summary import ( NOT_APPLICABLE, LayerSummary, ModelSummary, - _is_lazy_weight_tensor, + _tensor_has_shape, get_human_readable_count, ) @@ -40,7 +40,7 @@ class DeepSpeedLayerSummary(LayerSummary): @override def num_parameters(self) -> int: """Returns the number of parameters in this module.""" - return sum(deepspeed_param_size(p) if not _is_lazy_weight_tensor(p) else 0 for p in self._module.parameters()) + return sum(deepspeed_param_size(p) if not _tensor_has_shape(p) else 0 for p in self._module.parameters()) @property def average_shard_parameters(self) -> int: @@ -49,7 +49,7 @@ class DeepSpeedLayerSummary(LayerSummary): def partitioned_size(p: Parameter) -> int: return p.partitioned_size() if RequirementCache("deepspeed<0.6.6") else p.partition_numel() - return sum(partitioned_size(p) if not _is_lazy_weight_tensor(p) else 0 for p in self._module.parameters()) + return sum(partitioned_size(p) if not _tensor_has_shape(p) else 0 for p in self._module.parameters()) class DeepSpeedSummary(ModelSummary): @@ -71,13 +71,13 @@ class DeepSpeedSummary(ModelSummary): @property @override def total_parameters(self) -> int: - return sum(deepspeed_param_size(p) if not _is_lazy_weight_tensor(p) else 0 for p in self._model.parameters()) + return sum(deepspeed_param_size(p) if not _tensor_has_shape(p) else 0 for p in self._model.parameters()) @property @override def trainable_parameters(self) -> int: return sum( - deepspeed_param_size(p) if not _is_lazy_weight_tensor(p) else 0 + deepspeed_param_size(p) if not _tensor_has_shape(p) else 0 for p in self._model.parameters() if p.requires_grad ) diff --git a/tests/tests_fabric/utilities/test_distributed.py b/tests/tests_fabric/utilities/test_distributed.py index 2c30b3aa62..cc6c23bddb 100644 --- a/tests/tests_fabric/utilities/test_distributed.py +++ b/tests/tests_fabric/utilities/test_distributed.py @@ -3,7 +3,9 @@ import os from functools import partial from pathlib import Path from unittest import mock +from unittest.mock import Mock +import lightning.fabric import pytest import torch from lightning.fabric.accelerators import CPUAccelerator, CUDAAccelerator, MPSAccelerator @@ -15,6 +17,7 @@ from lightning.fabric.utilities.distributed import ( _gather_all_tensors, _InfiniteBarrier, _init_dist_connection, + _is_dtensor, _set_num_threads_if_needed, _suggested_max_num_threads, _sync_ddp, @@ -234,3 +237,14 @@ def test_init_dist_connection_registers_destruction_handler(_, atexit_mock): atexit_mock.reset_mock() _init_dist_connection(LightningEnvironment(), "gloo") atexit_mock.register.assert_not_called() + + +@RunIf(min_torch="2.4") +def test_is_dtensor(monkeypatch): + from torch.distributed._tensor import DTensor + + assert _is_dtensor(Mock(spec=DTensor)) + assert not _is_dtensor(torch.zeros(2, 2)) + + monkeypatch.setattr(lightning.fabric.utilities.distributed, "_TORCH_GREATER_EQUAL_2_4", False) + assert not _is_dtensor(Mock(spec=DTensor)) diff --git a/tests/tests_pytorch/utilities/test_model_summary.py b/tests/tests_pytorch/utilities/test_model_summary.py index 00fdf77d4c..cced6546aa 100644 --- a/tests/tests_pytorch/utilities/test_model_summary.py +++ b/tests/tests_pytorch/utilities/test_model_summary.py @@ -13,6 +13,7 @@ # limitations under the License. from collections import OrderedDict from typing import Any +from unittest import mock import pytest import torch @@ -345,6 +346,18 @@ def test_lazy_model_summary(): assert summary.trainable_parameters == 0 +@mock.patch("lightning.pytorch.utilities.model_summary.model_summary._is_dtensor", return_value=True) +def test_dtensor_model_summary(_): + """Test that the model summary can work with layers that have DTensor parameters.""" + # We mock the `_is_dtensor` to pretend parameters are DTensors, because testing with real DTensors + # would require setting up distributed + dtensor_model = UnorderedModel() + summary = ModelSummary(dtensor_model) + assert summary.total_layer_params > 0 + assert summary.total_parameters > 0 + assert summary.trainable_parameters > 0 + + @pytest.mark.parametrize("max_depth", [-1, 0, 1, 3, 999]) def test_max_depth_param(max_depth): """Test that only the modules up to the desired depth are shown."""