diff --git a/src/lightning/pytorch/trainer/connectors/logger_connector/result.py b/src/lightning/pytorch/trainer/connectors/logger_connector/result.py index 81630362a7..7e0ef43303 100644 --- a/src/lightning/pytorch/trainer/connectors/logger_connector/result.py +++ b/src/lightning/pytorch/trainer/connectors/logger_connector/result.py @@ -19,7 +19,7 @@ import torch from lightning_utilities.core.apply_func import apply_to_collection from torch import Tensor from torchmetrics import Metric -from typing_extensions import TypedDict +from typing_extensions import TypedDict, override from lightning.fabric.utilities import move_data_to_device from lightning.fabric.utilities.apply_func import convert_tensors_to_scalars @@ -204,6 +204,7 @@ class _ResultMetric(Metric): # this is defined here only because upstream is missing the type annotation self._forward_cache: Optional[Any] = None + @override def update(self, value: _VALUE, batch_size: int) -> None: if self.is_tensor: value = cast(Tensor, value) @@ -242,6 +243,7 @@ class _ResultMetric(Metric): self.value = value self._forward_cache = value._forward_cache + @override def compute(self) -> Tensor: if self.is_tensor: value = self.meta.sync(self.value.clone()) # `clone` because `sync` is in-place @@ -251,6 +253,7 @@ class _ResultMetric(Metric): return value return self.value.compute() + @override def reset(self) -> None: if self.is_tensor: super().reset() @@ -258,6 +261,7 @@ class _ResultMetric(Metric): self.value.reset() self.has_reset = True + @override def forward(self, value: _VALUE, batch_size: int) -> None: if self.meta.enable_graph: with torch.no_grad(): @@ -266,6 +270,7 @@ class _ResultMetric(Metric): # performance: skip the `torch.no_grad` context manager by calling `update` directly self.update(value, batch_size) + @override def _wrap_compute(self, compute: Any) -> Any: # Override to avoid syncing - we handle it ourselves. @wraps(compute) @@ -286,16 +291,19 @@ class _ResultMetric(Metric): return wrapped_func + @override def __setattr__(self, key: str, value: Any) -> None: # performance: skip the `torch.nn.Module.__setattr__` checks object.__setattr__(self, key, value) + @override def __repr__(self) -> str: state = f"{repr(self.meta.name)}, value={self.value}" if self.is_tensor and self.meta.is_mean_reduction: state += f", cumulated_batch_size={self.cumulated_batch_size}" return f"{self.__class__.__name__}({state})" + @override def to(self, *args: Any, **kwargs: Any) -> "_ResultMetric": d = self.__dict__ if _TORCH_GREATER_EQUAL_2_0: # https://github.com/pytorch/pytorch/issues/96198