Fix mypy in utilities.device_dtype_mixin (#8127)

This commit is contained in:
Daniel Stancl 2021-07-12 18:56:06 +02:00 committed by GitHub
parent 4f1e7be5ec
commit 91d98c8345
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 21 additions and 15 deletions

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Union
from typing import Any, Optional, Union
import torch
from torch.nn import Module
@ -21,9 +21,9 @@ from torch.nn import Module
class DeviceDtypeModuleMixin(Module):
__jit_unused_properties__ = ['device', 'dtype']
def __init__(self):
def __init__(self) -> None:
super().__init__()
self._dtype = torch.get_default_dtype()
self._dtype: Union[str, torch.dtype] = torch.get_default_dtype()
self._device = torch.device('cpu')
@property
@ -31,7 +31,7 @@ class DeviceDtypeModuleMixin(Module):
return self._dtype
@dtype.setter
def dtype(self, new_dtype: Union[str, torch.dtype]):
def dtype(self, new_dtype: Union[str, torch.dtype]) -> None:
# necessary to avoid infinite recursion
raise RuntimeError('Cannot set the dtype explicitly. Please use module.to(new_dtype).')
@ -45,7 +45,7 @@ class DeviceDtypeModuleMixin(Module):
return device
def to(self, *args, **kwargs) -> Module:
def to(self, *args: Any, **kwargs: Any) -> 'DeviceDtypeModuleMixin':
"""Moves and/or casts the parameters and buffers.
This can be called as
@ -108,7 +108,7 @@ class DeviceDtypeModuleMixin(Module):
self.__update_properties(device=out[0], dtype=out[1])
return super().to(*args, **kwargs)
def cuda(self, device: Optional[Union[torch.device, int]] = None) -> Module:
def cuda(self, device: Optional[Union[torch.device, int]] = None) -> 'DeviceDtypeModuleMixin':
"""Moves all model parameters and buffers to the GPU.
This also makes associated parameters and buffers different objects. So
it should be called before constructing optimizer if the module will
@ -121,11 +121,13 @@ class DeviceDtypeModuleMixin(Module):
Returns:
Module: self
"""
property_device = device if isinstance(device, torch.device) else torch.device('cuda', index=device)
property_device = (
device if isinstance(device, torch.device) else torch.device('cuda', index=device) # type: ignore
) # mypy expects `device` for `index` to be int, while `Optional[int]` is okay => ignore typing for now
self.__update_properties(device=property_device)
return super().cuda(device=device)
def cpu(self) -> Module:
def cpu(self) -> 'DeviceDtypeModuleMixin':
"""Moves all model parameters and buffers to the CPU.
Returns:
@ -134,7 +136,7 @@ class DeviceDtypeModuleMixin(Module):
self.__update_properties(device=torch.device('cpu'))
return super().cpu()
def type(self, dst_type: Union[str, torch.dtype]) -> Module:
def type(self, dst_type: Union[str, torch.dtype]) -> 'DeviceDtypeModuleMixin':
"""Casts all parameters and buffers to :attr:`dst_type`.
Arguments:
@ -146,8 +148,8 @@ class DeviceDtypeModuleMixin(Module):
self.__update_properties(dtype=dst_type)
return super().type(dst_type=dst_type)
def float(self) -> Module:
"""Casts all floating point parameters and buffers to float datatype.
def float(self) -> 'DeviceDtypeModuleMixin':
"""Casts all floating point parameters and buffers to ``float`` datatype.
Returns:
Module: self
@ -155,7 +157,7 @@ class DeviceDtypeModuleMixin(Module):
self.__update_properties(dtype=torch.float)
return super().float()
def double(self) -> Module:
def double(self) -> 'DeviceDtypeModuleMixin':
"""Casts all floating point parameters and buffers to ``double`` datatype.
Returns:
@ -164,7 +166,7 @@ class DeviceDtypeModuleMixin(Module):
self.__update_properties(dtype=torch.double)
return super().double()
def half(self) -> Module:
def half(self) -> 'DeviceDtypeModuleMixin':
"""Casts all floating point parameters and buffers to ``half`` datatype.
Returns:
@ -173,9 +175,11 @@ class DeviceDtypeModuleMixin(Module):
self.__update_properties(dtype=torch.half)
return super().half()
def __update_properties(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
def __update_properties(
self, device: Optional[torch.device] = None, dtype: Optional[Union[str, torch.dtype]] = None
) -> None:
def apply_fn(module):
def apply_fn(module: Union['DeviceDtypeModuleMixin', Module]) -> None:
if not isinstance(module, DeviceDtypeModuleMixin):
return
if device is not None:

View File

@ -183,6 +183,8 @@ ignore_errors = True
ignore_errors = True
[mypy-pytorch_lightning.utilities.cli]
ignore_errors = False
[mypy-pytorch_lightning.utilities.device_dtype_mixin]
ignore_errors = False
[mypy-pytorch_lightning.utilities.device_parser]
ignore_errors = False
[mypy-pytorch_lightning.utilities.parsing]