From 85b0356ea22242305bb55a8347ed5733cd806b0d Mon Sep 17 00:00:00 2001 From: Krishna Kalyan Date: Tue, 26 Jul 2022 09:38:06 +0200 Subject: [PATCH] Fix mypy errors attributed to `pytorch_lightning.core.mixins.device_dtype_mixin` (#13704) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Adrian Wälchli --- pyproject.toml | 1 - .../core/mixins/device_dtype_mixin.py | 27 +++++++------------ 2 files changed, 9 insertions(+), 19 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 177410cba7..5a710faf35 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,7 +52,6 @@ module = [ "pytorch_lightning.callbacks.stochastic_weight_avg", "pytorch_lightning.core.datamodule", "pytorch_lightning.core.decorators", - "pytorch_lightning.core.mixins.device_dtype_mixin", "pytorch_lightning.core.module", "pytorch_lightning.core.saving", "pytorch_lightning.demos.boring_classes", diff --git a/src/pytorch_lightning/core/mixins/device_dtype_mixin.py b/src/pytorch_lightning/core/mixins/device_dtype_mixin.py index 5f6397e456..b12e1cf042 100644 --- a/src/pytorch_lightning/core/mixins/device_dtype_mixin.py +++ b/src/pytorch_lightning/core/mixins/device_dtype_mixin.py @@ -16,16 +16,7 @@ from typing import Any, Optional, Union import torch from torch.nn import Module - -try: - from typing_extensions import Self -except ImportError: - # workaround for Python 3.7. - # see https://www.python.org/dev/peps/pep-0673/ - from typing import TypeVar - - Self = TypeVar("TDeviceDtypeModuleMixin", bound="DeviceDtypeModuleMixin") - +from typing_extensions import Self import pytorch_lightning as pl @@ -57,7 +48,7 @@ class DeviceDtypeModuleMixin(Module): return device - def to(self, *args: Any, **kwargs: Any) -> Self: + def to(self, *args: Any, **kwargs: Any) -> Self: # type: ignore[valid-type] """Moves and/or casts the parameters and buffers. This can be called as @@ -121,7 +112,7 @@ class DeviceDtypeModuleMixin(Module): self.__update_properties(device=out[0], dtype=out[1]) return super().to(*args, **kwargs) - def cuda(self, device: Optional[Union[torch.device, int]] = None) -> Self: + def cuda(self, device: Optional[Union[torch.device, int]] = None) -> Self: # type: ignore[valid-type] """Moves all model parameters and buffers to the GPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on GPU while being optimized. @@ -134,11 +125,11 @@ class DeviceDtypeModuleMixin(Module): Module: self """ if device is None or isinstance(device, int): - device = torch.device("cuda", index=device) + device = torch.device("cuda", index=(device or 0)) self.__update_properties(device=device) return super().cuda(device=device) - def cpu(self) -> Self: + def cpu(self) -> Self: # type: ignore[valid-type] """Moves all model parameters and buffers to the CPU. Returns: @@ -147,7 +138,7 @@ class DeviceDtypeModuleMixin(Module): self.__update_properties(device=torch.device("cpu")) return super().cpu() - def type(self, dst_type: Union[str, torch.dtype]) -> Self: + def type(self, dst_type: Union[str, torch.dtype]) -> Self: # type: ignore[valid-type] """Casts all parameters and buffers to :attr:`dst_type`. Arguments: @@ -159,7 +150,7 @@ class DeviceDtypeModuleMixin(Module): self.__update_properties(dtype=dst_type) return super().type(dst_type=dst_type) - def float(self) -> Self: + def float(self) -> Self: # type: ignore[valid-type] """Casts all floating point parameters and buffers to ``float`` datatype. Returns: @@ -168,7 +159,7 @@ class DeviceDtypeModuleMixin(Module): self.__update_properties(dtype=torch.float) return super().float() - def double(self) -> Self: + def double(self) -> Self: # type: ignore[valid-type] """Casts all floating point parameters and buffers to ``double`` datatype. Returns: @@ -177,7 +168,7 @@ class DeviceDtypeModuleMixin(Module): self.__update_properties(dtype=torch.double) return super().double() - def half(self) -> Self: + def half(self) -> Self: # type: ignore[valid-type] """Casts all floating point parameters and buffers to ``half`` datatype. Returns: