Add `@override` for files in `src/lightning/pytorch/trainer/connectors` (#18997)

Victor Prins 2023-11-14 16:24:04 +01:00 committed by GitHub
parent cb23fc2dd4
commit fc021f1a32
1 changed file with 9 additions and 1 deletion
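For context: `override` (PEP 698) is imported from `typing_extensions` so that static type checkers such as mypy and pyright can verify that each decorated method really overrides a member of a base class, and can flag the method if the base class API later changes. A minimal sketch with hypothetical class names, not taken from the Lightning codebase:

from typing_extensions import override


class Base:
    def compute(self) -> int:
        return 0


class Child(Base):
    @override
    def compute(self) -> int:  # OK: `Base.compute` exists
        return 1

    # Renaming or removing `Base.compute` would make a type checker reject
    # this decorated method instead of silently leaving a stale override.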

@@ -19,7 +19,7 @@ import torch
 from lightning_utilities.core.apply_func import apply_to_collection
 from torch import Tensor
 from torchmetrics import Metric
-from typing_extensions import TypedDict
+from typing_extensions import TypedDict, override

 from lightning.fabric.utilities import move_data_to_device
 from lightning.fabric.utilities.apply_func import convert_tensors_to_scalars
@@ -204,6 +204,7 @@ class _ResultMetric(Metric):
         # this is defined here only because upstream is missing the type annotation
         self._forward_cache: Optional[Any] = None

+    @override
     def update(self, value: _VALUE, batch_size: int) -> None:
         if self.is_tensor:
             value = cast(Tensor, value)
@@ -242,6 +243,7 @@ class _ResultMetric(Metric):
             self.value = value
             self._forward_cache = value._forward_cache

+    @override
     def compute(self) -> Tensor:
         if self.is_tensor:
             value = self.meta.sync(self.value.clone())  # `clone` because `sync` is in-place
@@ -251,6 +253,7 @@ class _ResultMetric(Metric):
             return value
         return self.value.compute()

+    @override
     def reset(self) -> None:
         if self.is_tensor:
             super().reset()
@@ -258,6 +261,7 @@ class _ResultMetric(Metric):
             self.value.reset()
         self.has_reset = True

+    @override
     def forward(self, value: _VALUE, batch_size: int) -> None:
         if self.meta.enable_graph:
             with torch.no_grad():
@@ -266,6 +270,7 @@ class _ResultMetric(Metric):
             # performance: skip the `torch.no_grad` context manager by calling `update` directly
             self.update(value, batch_size)

+    @override
     def _wrap_compute(self, compute: Any) -> Any:
         # Override to avoid syncing - we handle it ourselves.
         @wraps(compute)
@@ -286,16 +291,19 @@ class _ResultMetric(Metric):
         return wrapped_func

+    @override
     def __setattr__(self, key: str, value: Any) -> None:
         # performance: skip the `torch.nn.Module.__setattr__` checks
         object.__setattr__(self, key, value)

+    @override
     def __repr__(self) -> str:
         state = f"{repr(self.meta.name)}, value={self.value}"
         if self.is_tensor and self.meta.is_mean_reduction:
             state += f", cumulated_batch_size={self.cumulated_batch_size}"
         return f"{self.__class__.__name__}({state})"

+    @override
     def to(self, *args: Any, **kwargs: Any) -> "_ResultMetric":
         d = self.__dict__
         if _TORCH_GREATER_EQUAL_2_0:  # https://github.com/pytorch/pytorch/issues/96198
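Note that the `__setattr__` decorated above is a performance override rather than a behavioral one: it bypasses `torch.nn.Module.__setattr__` (which `torchmetrics.Metric` inherits), skipping the per-assignment parameter, buffer, and submodule bookkeeping. A minimal sketch of that fast path; `FastMetric` is a hypothetical stand-in for `_ResultMetric`, not part of Lightning:

from typing import Any

import torch
from torch import nn
from typing_extensions import override


class FastMetric(nn.Module):  # hypothetical, for illustration only
    @override
    def __setattr__(self, key: str, value: Any) -> None:
        # Skip nn.Module.__setattr__'s registry checks on every assignment.
        # Trade-off: assigning an nn.Parameter no longer auto-registers it.
        object.__setattr__(self, key, value)


m = FastMetric()
m.cache = torch.zeros(1)  # stored as a plain attribute, not a buffer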