From 91d98c83454abe9844a468575c62e3a7a503447b Mon Sep 17 00:00:00 2001
From: Daniel Stancl <46073029+stancld@users.noreply.github.com>
Date: Mon, 12 Jul 2021 18:56:06 +0200
Subject: [PATCH] Fix mypy in utilities.device_dtype_mixin (#8127)

---
 .../utilities/device_dtype_mixin.py | 34 +++++++++++--------
 setup.cfg                           |  2 ++
 2 files changed, 21 insertions(+), 15 deletions(-)

diff --git a/pytorch_lightning/utilities/device_dtype_mixin.py b/pytorch_lightning/utilities/device_dtype_mixin.py
index 13f16d9b42..3f68006aba 100644
--- a/pytorch_lightning/utilities/device_dtype_mixin.py
+++ b/pytorch_lightning/utilities/device_dtype_mixin.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional, Union
+from typing import Any, Optional, Union
 
 import torch
 from torch.nn import Module
@@ -21,9 +21,9 @@ from torch.nn import Module
 class DeviceDtypeModuleMixin(Module):
     __jit_unused_properties__ = ['device', 'dtype']
 
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
-        self._dtype = torch.get_default_dtype()
+        self._dtype: Union[str, torch.dtype] = torch.get_default_dtype()
         self._device = torch.device('cpu')
 
     @property
@@ -31,7 +31,7 @@ class DeviceDtypeModuleMixin(Module):
         return self._dtype
 
     @dtype.setter
-    def dtype(self, new_dtype: Union[str, torch.dtype]):
+    def dtype(self, new_dtype: Union[str, torch.dtype]) -> None:
         # necessary to avoid infinite recursion
         raise RuntimeError('Cannot set the dtype explicitly. Please use module.to(new_dtype).')
 
@@ -45,7 +45,7 @@ class DeviceDtypeModuleMixin(Module):
 
         return device
 
-    def to(self, *args, **kwargs) -> Module:
+    def to(self, *args: Any, **kwargs: Any) -> 'DeviceDtypeModuleMixin':
         """Moves and/or casts the parameters and buffers.
 
         This can be called as
@@ -108,7 +108,7 @@ class DeviceDtypeModuleMixin(Module):
         self.__update_properties(device=out[0], dtype=out[1])
         return super().to(*args, **kwargs)
 
-    def cuda(self, device: Optional[Union[torch.device, int]] = None) -> Module:
+    def cuda(self, device: Optional[Union[torch.device, int]] = None) -> 'DeviceDtypeModuleMixin':
         """Moves all model parameters and buffers to the GPU.
         This also makes associated parameters and buffers different objects. So
         it should be called before constructing optimizer if the module will
@@ -121,11 +121,13 @@ class DeviceDtypeModuleMixin(Module):
         Returns:
             Module: self
         """
-        property_device = device if isinstance(device, torch.device) else torch.device('cuda', index=device)
+        property_device = (
+            device if isinstance(device, torch.device) else torch.device('cuda', index=device)  # type: ignore
+        )  # mypy expects `device` for `index` to be int, while `Optional[int]` is okay => ignore typing for now
         self.__update_properties(device=property_device)
         return super().cuda(device=device)
 
-    def cpu(self) -> Module:
+    def cpu(self) -> 'DeviceDtypeModuleMixin':
         """Moves all model parameters and buffers to the CPU.
 
         Returns:
@@ -134,7 +136,7 @@ class DeviceDtypeModuleMixin(Module):
         self.__update_properties(device=torch.device('cpu'))
         return super().cpu()
 
-    def type(self, dst_type: Union[str, torch.dtype]) -> Module:
+    def type(self, dst_type: Union[str, torch.dtype]) -> 'DeviceDtypeModuleMixin':
         """Casts all parameters and buffers to :attr:`dst_type`.
 
         Arguments:
@@ -146,8 +148,8 @@ class DeviceDtypeModuleMixin(Module):
         self.__update_properties(dtype=dst_type)
         return super().type(dst_type=dst_type)
 
-    def float(self) -> Module:
-        """Casts all floating point parameters and buffers to float datatype.
+    def float(self) -> 'DeviceDtypeModuleMixin':
+        """Casts all floating point parameters and buffers to ``float`` datatype.
 
         Returns:
             Module: self
@@ -155,7 +157,7 @@ class DeviceDtypeModuleMixin(Module):
         self.__update_properties(dtype=torch.float)
         return super().float()
 
-    def double(self) -> Module:
+    def double(self) -> 'DeviceDtypeModuleMixin':
         """Casts all floating point parameters and buffers to ``double`` datatype.
 
         Returns:
@@ -164,7 +166,7 @@ class DeviceDtypeModuleMixin(Module):
         self.__update_properties(dtype=torch.double)
         return super().double()
 
-    def half(self) -> Module:
+    def half(self) -> 'DeviceDtypeModuleMixin':
         """Casts all floating point parameters and buffers to ``half`` datatype.
 
         Returns:
@@ -173,9 +175,11 @@ class DeviceDtypeModuleMixin(Module):
         self.__update_properties(dtype=torch.half)
         return super().half()
 
-    def __update_properties(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
+    def __update_properties(
+        self, device: Optional[torch.device] = None, dtype: Optional[Union[str, torch.dtype]] = None
+    ) -> None:
 
-        def apply_fn(module):
+        def apply_fn(module: Union['DeviceDtypeModuleMixin', Module]) -> None:
             if not isinstance(module, DeviceDtypeModuleMixin):
                 return
             if device is not None:
diff --git a/setup.cfg b/setup.cfg
index 8dae1dc720..fb9f171af0 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -183,6 +183,8 @@ ignore_errors = True
 ignore_errors = True
 [mypy-pytorch_lightning.utilities.cli]
 ignore_errors = False
+[mypy-pytorch_lightning.utilities.device_dtype_mixin]
+ignore_errors = False
 [mypy-pytorch_lightning.utilities.device_parser]
 ignore_errors = False
 [mypy-pytorch_lightning.utilities.parsing]
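
Note (reviewer sketch, not part of the patch): the switch from `-> Module` to
`-> 'DeviceDtypeModuleMixin'` matters for chained calls. Under the old
annotations, mypy narrows an expression like `model.double().cpu()` to plain
`Module` and then rejects mixin attributes such as `.device` and `.dtype`.
A minimal illustration (the `ToyModel` class is hypothetical, for
demonstration only):

    import torch
    from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin

    class ToyModel(DeviceDtypeModuleMixin):
        def __init__(self) -> None:
            super().__init__()
            self.layer = torch.nn.Linear(4, 4)

    # Chained calls now keep the mixin type, so the attribute accesses below
    # type-check; with the old `-> Module` annotations they did not.
    model = ToyModel().double().cpu()
    print(model.device, model.dtype)  # cpu torch.float64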
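
Note (reviewer sketch, not part of the patch): the `-> None` on the `dtype`
setter also documents its actual contract: it never assigns, it only raises,
forcing casts through `.to()`, which records the new dtype via
`__update_properties`. A quick runtime check, assuming the mixin is
instantiated on its own (it subclasses `Module`, so this works):

    import torch
    from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin

    module = DeviceDtypeModuleMixin()
    try:
        module.dtype = torch.half  # hits the guarded setter
    except RuntimeError as err:
        print(err)  # Cannot set the dtype explicitly. Please use module.to(new_dtype).

    module = module.to(torch.half)  # the sanctioned path; updates `_dtype`
    print(module.dtype)  # torch.float16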