Improve typing for logging (#10748)
Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
parent 31bb6e69ca · commit 78face65e8
|
@ -120,7 +120,6 @@ module = [
|
|||
"pytorch_lightning.trainer.connectors.callback_connector",
|
||||
"pytorch_lightning.trainer.connectors.checkpoint_connector",
|
||||
"pytorch_lightning.trainer.connectors.data_connector",
|
||||
"pytorch_lightning.trainer.connectors.logger_connector.result",
|
||||
"pytorch_lightning.trainer.data_loading",
|
||||
"pytorch_lightning.trainer.optimizers",
|
||||
"pytorch_lightning.trainer.supporters",
|
||||
|
|
|
@@ -17,6 +17,16 @@ from typing import Any, Optional, Union
 import torch
 from torch.nn import Module

+try:
+    from typing_extensions import Self
+except ImportError:
+    # workaround for Python 3.6 and 3.7.
+    # see https://www.python.org/dev/peps/pep-0673/
+    from typing import TypeVar
+
+    Self = TypeVar("TDeviceDtypeModuleMixin", bound="DeviceDtypeModuleMixin")
+
+
 import pytorch_lightning as pl

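Returning `Self` (PEP 673) instead of the hard-coded `"DeviceDtypeModuleMixin"` string means every subclass gets its own type back from `to()`, `cuda()`, `cpu()`, `type()`, `float()`, `double()` and `half()`; the `TypeVar` fallback above approximates the same behaviour where `typing_extensions.Self` cannot be imported. A minimal sketch of the idea, assuming `typing-extensions>=4.0.0` is available (the `Base`/`Child` names are illustrative, not Lightning code):

# With `-> Self`, a type checker infers the *subclass* type at the call site,
# so subclass-only attributes stay visible after a chained call.
from typing_extensions import Self


class Base:
    def configure(self) -> Self:
        return self


class Child(Base):
    def child_only(self) -> str:
        return "ok"


# `configure()` is declared on `Base`, yet the checker infers `Child` here,
# so the follow-up call type-checks without a cast.
print(Child().configure().child_only())
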
@@ -47,7 +57,7 @@ class DeviceDtypeModuleMixin(Module):

         return device

-    def to(self, *args: Any, **kwargs: Any) -> "DeviceDtypeModuleMixin":
+    def to(self, *args: Any, **kwargs: Any) -> Self:
         """Moves and/or casts the parameters and buffers.

         This can be called as

@@ -110,7 +120,7 @@ class DeviceDtypeModuleMixin(Module):
         self.__update_properties(device=out[0], dtype=out[1])
         return super().to(*args, **kwargs)

-    def cuda(self, device: Optional[Union[torch.device, int]] = None) -> "DeviceDtypeModuleMixin":
+    def cuda(self, device: Optional[Union[torch.device, int]] = None) -> Self:
         """Moves all model parameters and buffers to the GPU. This also makes associated parameters and buffers
         different objects. So it should be called before constructing optimizer if the module will live on GPU
         while being optimized.

@@ -127,7 +137,7 @@ class DeviceDtypeModuleMixin(Module):
         self.__update_properties(device=device)
         return super().cuda(device=device)

-    def cpu(self) -> "DeviceDtypeModuleMixin":
+    def cpu(self) -> Self:
         """Moves all model parameters and buffers to the CPU.

         Returns:

@@ -136,7 +146,7 @@ class DeviceDtypeModuleMixin(Module):
         self.__update_properties(device=torch.device("cpu"))
         return super().cpu()

-    def type(self, dst_type: Union[str, torch.dtype]) -> "DeviceDtypeModuleMixin":
+    def type(self, dst_type: Union[str, torch.dtype]) -> Self:
         """Casts all parameters and buffers to :attr:`dst_type`.

         Arguments:

@@ -148,7 +158,7 @@ class DeviceDtypeModuleMixin(Module):
         self.__update_properties(dtype=dst_type)
         return super().type(dst_type=dst_type)

-    def float(self) -> "DeviceDtypeModuleMixin":
+    def float(self) -> Self:
         """Casts all floating point parameters and buffers to ``float`` datatype.

         Returns:

@@ -157,7 +167,7 @@ class DeviceDtypeModuleMixin(Module):
         self.__update_properties(dtype=torch.float)
         return super().float()

-    def double(self) -> "DeviceDtypeModuleMixin":
+    def double(self) -> Self:
         """Casts all floating point parameters and buffers to ``double`` datatype.

         Returns:

@@ -166,7 +176,7 @@ class DeviceDtypeModuleMixin(Module):
         self.__update_properties(dtype=torch.double)
         return super().double()

-    def half(self) -> "DeviceDtypeModuleMixin":
+    def half(self) -> Self:
         """Casts all floating point parameters and buffers to ``half`` datatype.

         Returns:

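With these signatures, chained device and dtype calls no longer erase the concrete module type. An illustrative usage sketch, assuming `pytorch_lightning` and `torch` are installed (`MyModel` and its `layer` attribute are made up for the example):

import torch
import pytorch_lightning as pl


class MyModel(pl.LightningModule):
    def __init__(self) -> None:
        super().__init__()
        self.layer = torch.nn.Linear(4, 2)


# These methods were previously annotated to return `DeviceDtypeModuleMixin`,
# so subclass attributes needed a cast afterwards; with `Self`, the checker
# keeps the inferred type as `MyModel`.
model = MyModel().cpu().float()
print(model.layer)
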
@@ -211,8 +211,10 @@ class ResultMetric(Metric, DeviceDtypeModuleMixin):
         self.add_state("value", torch.tensor(0.0), dist_reduce_fx=torch.sum)
         if self.meta.is_mean_reduction:
             self.add_state("cumulated_batch_size", torch.tensor(0), dist_reduce_fx=torch.sum)
+        # this is defined here only because upstream is missing the type annotation
+        self._forward_cache: Optional[Any] = None

-    def update(self, value: _IN_METRIC, batch_size: int) -> None:
+    def update(self, value: _IN_METRIC, batch_size: int) -> None:  # type: ignore[override]
         if self.is_tensor:
             if not torch.is_floating_point(value):
                 dtype = torch.get_default_dtype()

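The `# type: ignore[override]` acknowledges that `update` narrows the signature inherited from `torchmetrics.Metric`, which mypy reports as an incompatible override even though it is intentional here. A minimal sketch of the same situation (class and parameter names are illustrative, not the torchmetrics API):

class MetricBase:
    def update(self, value: object) -> None:
        ...


class MeanMetric(MetricBase):
    # Narrowing `value` from `object` to `float` breaks substitutability, so
    # mypy emits an `[override]` error; the ignore marks it as deliberate.
    def update(self, value: float, batch_size: int = 1) -> None:  # type: ignore[override]
        ...
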
@@ -225,16 +227,17 @@ class ResultMetric(Metric, DeviceDtypeModuleMixin):

             if self.meta.on_step:
                 self._forward_cache = self.meta.sync(value.clone())  # `clone` because `sync` is in-place
-
-                # performance: no need to accumulate on values only logged on_step
-                if not self.meta.on_epoch:
-                    self.value = self._forward_cache
-                    return
+                # performance: no need to accumulate on values only logged on_step
+                if not self.meta.on_epoch:
+                    self.value = self._forward_cache
+                    return

             # perform accumulation with reduction
             if self.meta.is_mean_reduction:
                 self.value += value.mean() * batch_size
-                self.cumulated_batch_size += batch_size
+                # `Metric.add_state` does not work well with mypy, mypy doesn't know this is a `Tensor`
+                # we could add an assertion, but this is a hot code path
+                self.cumulated_batch_size += batch_size  # type: ignore[operator]
             elif self.meta.is_max_reduction or self.meta.is_min_reduction:
                 self.value = self.meta.reduce_fx(self.value, value.mean())
             elif self.meta.is_sum_reduction:

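The targeted `# type: ignore[operator]` deals with a known blind spot: `add_state` registers `cumulated_batch_size` dynamically, so mypy never sees an assignment to it and the attribute falls back to `torch.nn.Module.__getattr__`, which (in the PyTorch releases of this era) is annotated as returning a `Tensor`/`Module` union rather than a plain `Tensor`. A rough sketch of the trade-off, using made-up names rather than Lightning/torchmetrics code:

from typing import Union

import torch
from torch import Tensor
from torch.nn import Module


class Accumulator:
    # mimic what the type checker sees for a dynamically registered state
    cumulated_batch_size: Union[Tensor, Module]

    def __init__(self) -> None:
        self.cumulated_batch_size = torch.tensor(0)

    def update_with_assert(self, batch_size: int) -> None:
        # option 1: narrow with an assertion; satisfies mypy but runs an
        # isinstance check on every call of a hot code path
        assert isinstance(self.cumulated_batch_size, Tensor)
        self.cumulated_batch_size += batch_size

    def update_with_ignore(self, batch_size: int) -> None:
        # option 2 (chosen in the diff above): silence the single false
        # positive at zero runtime cost
        self.cumulated_batch_size += batch_size  # type: ignore[operator]
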
@@ -10,4 +10,4 @@ tensorboard>=2.2.0
 torchmetrics>=0.4.1
 pyDeprecate==0.3.1
 packaging>=17.0
-typing-extensions
+typing-extensions>=4.0.0