Add `@override` for files in `src/lightning/pytorch/trainer/connectors` (#18997)
This commit is contained in:
parent
cb23fc2dd4
commit
fc021f1a32
|
@ -19,7 +19,7 @@ import torch
|
|||
from lightning_utilities.core.apply_func import apply_to_collection
|
||||
from torch import Tensor
|
||||
from torchmetrics import Metric
|
||||
from typing_extensions import TypedDict
|
||||
from typing_extensions import TypedDict, override
|
||||
|
||||
from lightning.fabric.utilities import move_data_to_device
|
||||
from lightning.fabric.utilities.apply_func import convert_tensors_to_scalars
|
||||
|
@ -204,6 +204,7 @@ class _ResultMetric(Metric):
|
|||
# this is defined here only because upstream is missing the type annotation
|
||||
self._forward_cache: Optional[Any] = None
|
||||
|
||||
@override
|
||||
def update(self, value: _VALUE, batch_size: int) -> None:
|
||||
if self.is_tensor:
|
||||
value = cast(Tensor, value)
|
||||
|
@ -242,6 +243,7 @@ class _ResultMetric(Metric):
|
|||
self.value = value
|
||||
self._forward_cache = value._forward_cache
|
||||
|
||||
@override
|
||||
def compute(self) -> Tensor:
|
||||
if self.is_tensor:
|
||||
value = self.meta.sync(self.value.clone()) # `clone` because `sync` is in-place
|
||||
|
@ -251,6 +253,7 @@ class _ResultMetric(Metric):
|
|||
return value
|
||||
return self.value.compute()
|
||||
|
||||
@override
|
||||
def reset(self) -> None:
|
||||
if self.is_tensor:
|
||||
super().reset()
|
||||
|
@ -258,6 +261,7 @@ class _ResultMetric(Metric):
|
|||
self.value.reset()
|
||||
self.has_reset = True
|
||||
|
||||
@override
|
||||
def forward(self, value: _VALUE, batch_size: int) -> None:
|
||||
if self.meta.enable_graph:
|
||||
with torch.no_grad():
|
||||
|
@ -266,6 +270,7 @@ class _ResultMetric(Metric):
|
|||
# performance: skip the `torch.no_grad` context manager by calling `update` directly
|
||||
self.update(value, batch_size)
|
||||
|
||||
@override
|
||||
def _wrap_compute(self, compute: Any) -> Any:
|
||||
# Override to avoid syncing - we handle it ourselves.
|
||||
@wraps(compute)
|
||||
|
@ -286,16 +291,19 @@ class _ResultMetric(Metric):
|
|||
|
||||
return wrapped_func
|
||||
|
||||
@override
|
||||
def __setattr__(self, key: str, value: Any) -> None:
|
||||
# performance: skip the `torch.nn.Module.__setattr__` checks
|
||||
object.__setattr__(self, key, value)
|
||||
|
||||
@override
|
||||
def __repr__(self) -> str:
|
||||
state = f"{repr(self.meta.name)}, value={self.value}"
|
||||
if self.is_tensor and self.meta.is_mean_reduction:
|
||||
state += f", cumulated_batch_size={self.cumulated_batch_size}"
|
||||
return f"{self.__class__.__name__}({state})"
|
||||
|
||||
@override
|
||||
def to(self, *args: Any, **kwargs: Any) -> "_ResultMetric":
|
||||
d = self.__dict__
|
||||
if _TORCH_GREATER_EQUAL_2_0: # https://github.com/pytorch/pytorch/issues/96198
|
||||
|
|
Loading…
Reference in New Issue