Improve typing for logging (#10748)

Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
Carlos Mocholí 2021-11-26 19:12:21 +01:00 committed by GitHub
parent 31bb6e69ca
commit 78face65e8
4 changed files with 28 additions and 16 deletions

pyproject.toml

@@ -120,7 +120,6 @@ module = [
"pytorch_lightning.trainer.connectors.callback_connector",
"pytorch_lightning.trainer.connectors.checkpoint_connector",
"pytorch_lightning.trainer.connectors.data_connector",
"pytorch_lightning.trainer.connectors.logger_connector.result",
"pytorch_lightning.trainer.data_loading",
"pytorch_lightning.trainer.optimizers",
"pytorch_lightning.trainer.supporters",

pytorch_lightning/core/mixins/device_dtype_mixin.py

@@ -17,6 +17,16 @@ from typing import Any, Optional, Union
import torch
from torch.nn import Module
+try:
+from typing_extensions import Self
+except ImportError:
+# fallback for environments whose `typing_extensions` predates `Self` (PEP 673)
+# see https://www.python.org/dev/peps/pep-0673/
+from typing import TypeVar
+Self = TypeVar("Self", bound="DeviceDtypeModuleMixin")
import pytorch_lightning as pl
@@ -47,7 +57,7 @@ class DeviceDtypeModuleMixin(Module):
return device
-def to(self, *args: Any, **kwargs: Any) -> "DeviceDtypeModuleMixin":
+def to(self, *args: Any, **kwargs: Any) -> Self:
"""Moves and/or casts the parameters and buffers.
This can be called as
@@ -110,7 +120,7 @@ class DeviceDtypeModuleMixin(Module):
self.__update_properties(device=out[0], dtype=out[1])
return super().to(*args, **kwargs)
-def cuda(self, device: Optional[Union[torch.device, int]] = None) -> "DeviceDtypeModuleMixin":
+def cuda(self, device: Optional[Union[torch.device, int]] = None) -> Self:
"""Moves all model parameters and buffers to the GPU. This also makes associated parameters and buffers
different objects. So it should be called before constructing optimizer if the module will live on GPU
while being optimized.
@@ -127,7 +137,7 @@ class DeviceDtypeModuleMixin(Module):
self.__update_properties(device=device)
return super().cuda(device=device)
-def cpu(self) -> "DeviceDtypeModuleMixin":
+def cpu(self) -> Self:
"""Moves all model parameters and buffers to the CPU.
Returns:
@@ -136,7 +146,7 @@ class DeviceDtypeModuleMixin(Module):
self.__update_properties(device=torch.device("cpu"))
return super().cpu()
-def type(self, dst_type: Union[str, torch.dtype]) -> "DeviceDtypeModuleMixin":
+def type(self, dst_type: Union[str, torch.dtype]) -> Self:
"""Casts all parameters and buffers to :attr:`dst_type`.
Arguments:
@@ -148,7 +158,7 @@ class DeviceDtypeModuleMixin(Module):
self.__update_properties(dtype=dst_type)
return super().type(dst_type=dst_type)
-def float(self) -> "DeviceDtypeModuleMixin":
+def float(self) -> Self:
"""Casts all floating point parameters and buffers to ``float`` datatype.
Returns:
@@ -157,7 +167,7 @@ class DeviceDtypeModuleMixin(Module):
self.__update_properties(dtype=torch.float)
return super().float()
-def double(self) -> "DeviceDtypeModuleMixin":
+def double(self) -> Self:
"""Casts all floating point parameters and buffers to ``double`` datatype.
Returns:
@@ -166,7 +176,7 @@ class DeviceDtypeModuleMixin(Module):
self.__update_properties(dtype=torch.double)
return super().double()
-def half(self) -> "DeviceDtypeModuleMixin":
+def half(self) -> Self:
"""Casts all floating point parameters and buffers to ``half`` datatype.
Returns:

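For context (not part of the commit): switching the return annotations from the concrete "DeviceDtypeModuleMixin" string to Self lets a type checker carry the actual subclass type through calls such as .cuda(). A minimal sketch of the effect, assuming typing-extensions>=4.0.0 is installed; SelfTypedMixin and ExtendedModule are hypothetical stand-ins rather than Lightning classes:

from typing import Optional, Union

import torch
from torch.nn import Module
from typing_extensions import Self


class SelfTypedMixin(Module):
    def cuda(self, device: Optional[Union[torch.device, int]] = None) -> Self:
        # `Self` binds to the runtime subclass, mirroring the annotations in the diff above
        return super().cuda(device=device)


class ExtendedModule(SelfTypedMixin):
    extra_attribute: int = 0


model = ExtendedModule().cuda()  # parameter-free module, so this is a no-op even without a GPU
# under the old concrete return annotation `model` would be narrowed to the mixin and the
# attribute access below would not type-check; with `Self` it is inferred as `ExtendedModule`
print(model.extra_attribute)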
pytorch_lightning/trainer/connectors/logger_connector/result.py

@@ -211,8 +211,10 @@ class ResultMetric(Metric, DeviceDtypeModuleMixin):
self.add_state("value", torch.tensor(0.0), dist_reduce_fx=torch.sum)
if self.meta.is_mean_reduction:
self.add_state("cumulated_batch_size", torch.tensor(0), dist_reduce_fx=torch.sum)
+# this is defined here only because upstream (torchmetrics) is missing the type annotation
+self._forward_cache: Optional[Any] = None
-def update(self, value: _IN_METRIC, batch_size: int) -> None:
+def update(self, value: _IN_METRIC, batch_size: int) -> None:  # type: ignore[override]
if self.is_tensor:
if not torch.is_floating_point(value):
dtype = torch.get_default_dtype()
@@ -225,16 +227,17 @@ class ResultMetric(Metric, DeviceDtypeModuleMixin):
if self.meta.on_step:
self._forward_cache = self.meta.sync(value.clone()) # `clone` because `sync` is in-place
-# performance: no need to accumulate on values only logged on_step
-if not self.meta.on_epoch:
-self.value = self._forward_cache
-return
+# performance: no need to accumulate on values only logged on_step
+if not self.meta.on_epoch:
+self.value = self._forward_cache
+return
# perform accumulation with reduction
if self.meta.is_mean_reduction:
self.value += value.mean() * batch_size
-self.cumulated_batch_size += batch_size
+# `Metric.add_state` does not play well with mypy: it cannot tell that this state is a `Tensor`
+# we could add an assertion, but this is a hot code path
+self.cumulated_batch_size += batch_size  # type: ignore[operator]
elif self.meta.is_max_reduction or self.meta.is_min_reduction:
self.value = self.meta.reduce_fx(self.value, value.mean())
elif self.meta.is_sum_reduction:

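With pytorch_lightning.trainer.connectors.logger_connector.result no longer excluded from mypy, the two ignore codes above become necessary. A hedged sketch of why; HypotheticalMetric and HypotheticalResultMetric are illustrative only and do not reproduce the real torchmetrics API:

from typing import List, Union

import torch


class HypotheticalMetric:
    # a state registered at runtime is only known to mypy through a loose declaration, e.g. a Union
    cumulated_batch_size: Union[torch.Tensor, List[torch.Tensor]]

    def update(self) -> None:
        ...


class HypotheticalResultMetric(HypotheticalMetric):
    # adding required parameters to an inherited method is an incompatible override
    # for mypy, hence the `# type: ignore[override]` on `update` in the diff
    def update(self, value: torch.Tensor, batch_size: int) -> None:  # type: ignore[override]
        # mypy only sees the Union declared on the base class, so `+=` with an int is an
        # unsupported operand; the diff silences it with `# type: ignore[operator]` rather
        # than paying for an isinstance assert in a hot code path
        self.cumulated_batch_size += batch_size  # type: ignore[operator]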
requirements.txt

@@ -10,4 +10,4 @@ tensorboard>=2.2.0
torchmetrics>=0.4.1
pyDeprecate==0.3.1
packaging>=17.0
-typing-extensions
+typing-extensions>=4.0.0
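The raised lower bound is presumably because Self first shipped in typing-extensions 4.0.0; typing.Self itself only arrives with Python 3.11 (PEP 673). A quick sketch of where the name resolves from, assuming typing-extensions>=4.0.0 is installed:

import sys

if sys.version_info >= (3, 11):
    from typing import Self  # in the standard library from Python 3.11 onwards
else:
    from typing_extensions import Self  # available since typing-extensions 4.0.0

print(Self)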