diff --git a/pyproject.toml b/pyproject.toml
index c266e0684e..2471be131c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -120,7 +120,6 @@ module = [
     "pytorch_lightning.trainer.connectors.callback_connector",
     "pytorch_lightning.trainer.connectors.checkpoint_connector",
     "pytorch_lightning.trainer.connectors.data_connector",
-    "pytorch_lightning.trainer.connectors.logger_connector.result",
     "pytorch_lightning.trainer.data_loading",
     "pytorch_lightning.trainer.optimizers",
     "pytorch_lightning.trainer.supporters",
diff --git a/pytorch_lightning/core/mixins/device_dtype_mixin.py b/pytorch_lightning/core/mixins/device_dtype_mixin.py
index e8b122989c..d902958b9b 100644
--- a/pytorch_lightning/core/mixins/device_dtype_mixin.py
+++ b/pytorch_lightning/core/mixins/device_dtype_mixin.py
@@ -17,6 +17,16 @@ from typing import Any, Optional, Union
 import torch
 from torch.nn import Module
 
+try:
+    from typing_extensions import Self
+except ImportError:
+    # workaround for Python 3.6 and 3.7.
+    # see https://www.python.org/dev/peps/pep-0673/
+    from typing import TypeVar
+
+    Self = TypeVar("TDeviceDtypeModuleMixin", bound="DeviceDtypeModuleMixin")
+
+
 import pytorch_lightning as pl
 
 
@@ -47,7 +57,7 @@ class DeviceDtypeModuleMixin(Module):
 
         return device
 
-    def to(self, *args: Any, **kwargs: Any) -> "DeviceDtypeModuleMixin":
+    def to(self, *args: Any, **kwargs: Any) -> Self:
         """Moves and/or casts the parameters and buffers.
 
         This can be called as
@@ -110,7 +120,7 @@ class DeviceDtypeModuleMixin(Module):
         self.__update_properties(device=out[0], dtype=out[1])
         return super().to(*args, **kwargs)
 
-    def cuda(self, device: Optional[Union[torch.device, int]] = None) -> "DeviceDtypeModuleMixin":
+    def cuda(self, device: Optional[Union[torch.device, int]] = None) -> Self:
         """Moves all model parameters and buffers to the GPU. This also makes associated parameters and buffers
         different objects. So it should be called before constructing optimizer if the module will live on GPU
         while being optimized.
@@ -127,7 +137,7 @@ class DeviceDtypeModuleMixin(Module):
         self.__update_properties(device=device)
         return super().cuda(device=device)
 
-    def cpu(self) -> "DeviceDtypeModuleMixin":
+    def cpu(self) -> Self:
         """Moves all model parameters and buffers to the CPU.
 
         Returns:
@@ -136,7 +146,7 @@ class DeviceDtypeModuleMixin(Module):
         self.__update_properties(device=torch.device("cpu"))
         return super().cpu()
 
-    def type(self, dst_type: Union[str, torch.dtype]) -> "DeviceDtypeModuleMixin":
+    def type(self, dst_type: Union[str, torch.dtype]) -> Self:
         """Casts all parameters and buffers to :attr:`dst_type`.
 
         Arguments:
@@ -148,7 +158,7 @@ class DeviceDtypeModuleMixin(Module):
         self.__update_properties(dtype=dst_type)
         return super().type(dst_type=dst_type)
 
-    def float(self) -> "DeviceDtypeModuleMixin":
+    def float(self) -> Self:
         """Casts all floating point parameters and buffers to ``float`` datatype.
 
         Returns:
@@ -157,7 +167,7 @@ class DeviceDtypeModuleMixin(Module):
         self.__update_properties(dtype=torch.float)
         return super().float()
 
-    def double(self) -> "DeviceDtypeModuleMixin":
+    def double(self) -> Self:
         """Casts all floating point parameters and buffers to ``double`` datatype.
 
         Returns:
@@ -166,7 +176,7 @@ class DeviceDtypeModuleMixin(Module):
         self.__update_properties(dtype=torch.double)
         return super().double()
 
-    def half(self) -> "DeviceDtypeModuleMixin":
+    def half(self) -> Self:
         """Casts all floating point parameters and buffers to ``half`` datatype.
 
         Returns:
diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py
index e10360a5fb..1c27b75854 100644
--- a/pytorch_lightning/trainer/connectors/logger_connector/result.py
+++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py
@@ -211,8 +211,10 @@ class ResultMetric(Metric, DeviceDtypeModuleMixin):
             self.add_state("value", torch.tensor(0.0), dist_reduce_fx=torch.sum)
             if self.meta.is_mean_reduction:
                 self.add_state("cumulated_batch_size", torch.tensor(0), dist_reduce_fx=torch.sum)
+        # this is defined here only because upstream is missing the type annotation
+        self._forward_cache: Optional[Any] = None
 
-    def update(self, value: _IN_METRIC, batch_size: int) -> None:
+    def update(self, value: _IN_METRIC, batch_size: int) -> None:  # type: ignore[override]
         if self.is_tensor:
             if not torch.is_floating_point(value):
                 dtype = torch.get_default_dtype()
@@ -225,16 +227,17 @@ class ResultMetric(Metric, DeviceDtypeModuleMixin):
 
             if self.meta.on_step:
                 self._forward_cache = self.meta.sync(value.clone())  # `clone` because `sync` is in-place
-
-                # performance: no need to accumulate on values only logged on_step
-                if not self.meta.on_epoch:
-                    self.value = self._forward_cache
-                    return
+                # performance: no need to accumulate on values only logged on_step
+                if not self.meta.on_epoch:
+                    self.value = self._forward_cache
+                    return
 
             # perform accumulation with reduction
             if self.meta.is_mean_reduction:
                 self.value += value.mean() * batch_size
-                self.cumulated_batch_size += batch_size
+                # `Metric.add_state` does not work well with mypy, mypy doesn't know this is a `Tensor`
+                # we could add an assertion, but this is a hot code path
+                self.cumulated_batch_size += batch_size  # type: ignore[operator]
             elif self.meta.is_max_reduction or self.meta.is_min_reduction:
                 self.value = self.meta.reduce_fx(self.value, value.mean())
             elif self.meta.is_sum_reduction:
diff --git a/requirements.txt b/requirements.txt
index 34879d9290..94b7151d73 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -10,4 +10,4 @@ tensorboard>=2.2.0
 torchmetrics>=0.4.1
 pyDeprecate==0.3.1
 packaging>=17.0
-typing-extensions
+typing-extensions>=4.0.0
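
Note on the `Self` return type used in device_dtype_mixin.py: `Self` (PEP 673) ships in typing-extensions >= 4.0.0, which is why requirements.txt is bumped, and it lets type checkers resolve the return type of `to()`, `cuda()`, `cpu()`, `type()`, `float()`, `double()` and `half()` to the calling subclass instead of `DeviceDtypeModuleMixin`. A minimal sketch of the pattern outside Lightning; the `Base`/`Child` names are illustrative only, and the fallback branch mirrors the diff on a best-effort basis:

    try:
        from typing_extensions import Self
    except ImportError:  # best-effort fallback for interpreters without typing-extensions >= 4.0
        from typing import TypeVar

        Self = TypeVar("Self", bound="Base")


    class Base:
        def configure(self) -> Self:
            # annotating the return as `Self` preserves the concrete subclass type for callers
            return self


    class Child(Base):
        def child_only(self) -> int:
            return 1


    # a type checker infers `Child` for `Child().configure()`, so subclass-only chaining is accepted
    assert Child().configure().child_only() == 1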
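On the `ResultMetric.update` hunk: under mean reduction the metric accumulates a batch-size-weighted sum in `value` and the number of aggregated samples in `cumulated_batch_size`, so the epoch-level value can later be recovered as their ratio. A standalone sketch of that bookkeeping with plain tensors (an assumption-level illustration, not the Lightning class itself):

    import torch

    value = torch.tensor(0.0)
    cumulated_batch_size = torch.tensor(0)

    # two logged batches: mean 2.0 over 4 samples, mean 5.0 over 1 sample
    for batch_mean, batch_size in [(torch.tensor(2.0), 4), (torch.tensor(5.0), 1)]:
        value += batch_mean * batch_size        # weighted sum of per-batch means
        cumulated_batch_size += batch_size      # total samples accumulated

    epoch_mean = value / cumulated_batch_size   # (2.0 * 4 + 5.0 * 1) / 5 = 2.6
    print(epoch_mean)  # tensor(2.6000)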