Fix mypy in utilities.device_dtype_mixin (#8127)
This commit is contained in:
parent
4f1e7be5ec
commit
91d98c8345
|
@ -12,7 +12,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional, Union
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch.nn import Module
|
||||
|
@ -21,9 +21,9 @@ from torch.nn import Module
|
|||
class DeviceDtypeModuleMixin(Module):
|
||||
__jit_unused_properties__ = ['device', 'dtype']
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self._dtype = torch.get_default_dtype()
|
||||
self._dtype: Union[str, torch.dtype] = torch.get_default_dtype()
|
||||
self._device = torch.device('cpu')
|
||||
|
||||
@property
|
||||
|
@ -31,7 +31,7 @@ class DeviceDtypeModuleMixin(Module):
|
|||
return self._dtype
|
||||
|
||||
@dtype.setter
|
||||
def dtype(self, new_dtype: Union[str, torch.dtype]):
|
||||
def dtype(self, new_dtype: Union[str, torch.dtype]) -> None:
|
||||
# necessary to avoid infinite recursion
|
||||
raise RuntimeError('Cannot set the dtype explicitly. Please use module.to(new_dtype).')
|
||||
|
||||
|
@ -45,7 +45,7 @@ class DeviceDtypeModuleMixin(Module):
|
|||
|
||||
return device
|
||||
|
||||
def to(self, *args, **kwargs) -> Module:
|
||||
def to(self, *args: Any, **kwargs: Any) -> 'DeviceDtypeModuleMixin':
|
||||
"""Moves and/or casts the parameters and buffers.
|
||||
|
||||
This can be called as
|
||||
|
@ -108,7 +108,7 @@ class DeviceDtypeModuleMixin(Module):
|
|||
self.__update_properties(device=out[0], dtype=out[1])
|
||||
return super().to(*args, **kwargs)
|
||||
|
||||
def cuda(self, device: Optional[Union[torch.device, int]] = None) -> Module:
|
||||
def cuda(self, device: Optional[Union[torch.device, int]] = None) -> 'DeviceDtypeModuleMixin':
|
||||
"""Moves all model parameters and buffers to the GPU.
|
||||
This also makes associated parameters and buffers different objects. So
|
||||
it should be called before constructing optimizer if the module will
|
||||
|
@ -121,11 +121,13 @@ class DeviceDtypeModuleMixin(Module):
|
|||
Returns:
|
||||
Module: self
|
||||
"""
|
||||
property_device = device if isinstance(device, torch.device) else torch.device('cuda', index=device)
|
||||
property_device = (
|
||||
device if isinstance(device, torch.device) else torch.device('cuda', index=device) # type: ignore
|
||||
) # mypy expects `device` for `index` to be int, while `Optional[int]` is okay => ignore typing for now
|
||||
self.__update_properties(device=property_device)
|
||||
return super().cuda(device=device)
|
||||
|
||||
def cpu(self) -> Module:
|
||||
def cpu(self) -> 'DeviceDtypeModuleMixin':
|
||||
"""Moves all model parameters and buffers to the CPU.
|
||||
|
||||
Returns:
|
||||
|
@ -134,7 +136,7 @@ class DeviceDtypeModuleMixin(Module):
|
|||
self.__update_properties(device=torch.device('cpu'))
|
||||
return super().cpu()
|
||||
|
||||
def type(self, dst_type: Union[str, torch.dtype]) -> Module:
|
||||
def type(self, dst_type: Union[str, torch.dtype]) -> 'DeviceDtypeModuleMixin':
|
||||
"""Casts all parameters and buffers to :attr:`dst_type`.
|
||||
|
||||
Arguments:
|
||||
|
@ -146,8 +148,8 @@ class DeviceDtypeModuleMixin(Module):
|
|||
self.__update_properties(dtype=dst_type)
|
||||
return super().type(dst_type=dst_type)
|
||||
|
||||
def float(self) -> Module:
|
||||
"""Casts all floating point parameters and buffers to float datatype.
|
||||
def float(self) -> 'DeviceDtypeModuleMixin':
|
||||
"""Casts all floating point parameters and buffers to ``float`` datatype.
|
||||
|
||||
Returns:
|
||||
Module: self
|
||||
|
@ -155,7 +157,7 @@ class DeviceDtypeModuleMixin(Module):
|
|||
self.__update_properties(dtype=torch.float)
|
||||
return super().float()
|
||||
|
||||
def double(self) -> Module:
|
||||
def double(self) -> 'DeviceDtypeModuleMixin':
|
||||
"""Casts all floating point parameters and buffers to ``double`` datatype.
|
||||
|
||||
Returns:
|
||||
|
@ -164,7 +166,7 @@ class DeviceDtypeModuleMixin(Module):
|
|||
self.__update_properties(dtype=torch.double)
|
||||
return super().double()
|
||||
|
||||
def half(self) -> Module:
|
||||
def half(self) -> 'DeviceDtypeModuleMixin':
|
||||
"""Casts all floating point parameters and buffers to ``half`` datatype.
|
||||
|
||||
Returns:
|
||||
|
@ -173,9 +175,11 @@ class DeviceDtypeModuleMixin(Module):
|
|||
self.__update_properties(dtype=torch.half)
|
||||
return super().half()
|
||||
|
||||
def __update_properties(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
|
||||
def __update_properties(
|
||||
self, device: Optional[torch.device] = None, dtype: Optional[Union[str, torch.dtype]] = None
|
||||
) -> None:
|
||||
|
||||
def apply_fn(module):
|
||||
def apply_fn(module: Union['DeviceDtypeModuleMixin', Module]) -> None:
|
||||
if not isinstance(module, DeviceDtypeModuleMixin):
|
||||
return
|
||||
if device is not None:
|
||||
|
|
|
@ -183,6 +183,8 @@ ignore_errors = True
|
|||
ignore_errors = True
|
||||
[mypy-pytorch_lightning.utilities.cli]
|
||||
ignore_errors = False
|
||||
[mypy-pytorch_lightning.utilities.device_dtype_mixin]
|
||||
ignore_errors = False
|
||||
[mypy-pytorch_lightning.utilities.device_parser]
|
||||
ignore_errors = False
|
||||
[mypy-pytorch_lightning.utilities.parsing]
|
||||
|
|
Loading…
Reference in New Issue