Fix mypy errors attributed to `pytorch_lightning.core.mixins.device_dtype_mixin` (#13704)

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
This commit is contained in:
Krishna Kalyan 2022-07-26 09:38:06 +02:00 committed by GitHub
parent 4c35867b61
commit 85b0356ea2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 9 additions and 19 deletions

View File

@ -52,7 +52,6 @@ module = [
"pytorch_lightning.callbacks.stochastic_weight_avg",
"pytorch_lightning.core.datamodule",
"pytorch_lightning.core.decorators",
"pytorch_lightning.core.mixins.device_dtype_mixin",
"pytorch_lightning.core.module",
"pytorch_lightning.core.saving",
"pytorch_lightning.demos.boring_classes",

View File

@ -16,16 +16,7 @@ from typing import Any, Optional, Union
import torch
from torch.nn import Module
try:
from typing_extensions import Self
except ImportError:
# workaround for Python 3.7.
# see https://www.python.org/dev/peps/pep-0673/
from typing import TypeVar
Self = TypeVar("TDeviceDtypeModuleMixin", bound="DeviceDtypeModuleMixin")
from typing_extensions import Self
import pytorch_lightning as pl
@ -57,7 +48,7 @@ class DeviceDtypeModuleMixin(Module):
return device
def to(self, *args: Any, **kwargs: Any) -> Self:
def to(self, *args: Any, **kwargs: Any) -> Self: # type: ignore[valid-type]
"""Moves and/or casts the parameters and buffers.
This can be called as
@ -121,7 +112,7 @@ class DeviceDtypeModuleMixin(Module):
self.__update_properties(device=out[0], dtype=out[1])
return super().to(*args, **kwargs)
def cuda(self, device: Optional[Union[torch.device, int]] = None) -> Self:
def cuda(self, device: Optional[Union[torch.device, int]] = None) -> Self: # type: ignore[valid-type]
"""Moves all model parameters and buffers to the GPU. This also makes associated parameters and buffers
different objects. So it should be called before constructing optimizer if the module will live on GPU
while being optimized.
@ -134,11 +125,11 @@ class DeviceDtypeModuleMixin(Module):
Module: self
"""
if device is None or isinstance(device, int):
device = torch.device("cuda", index=device)
device = torch.device("cuda", index=(device or 0))
self.__update_properties(device=device)
return super().cuda(device=device)
def cpu(self) -> Self:
def cpu(self) -> Self: # type: ignore[valid-type]
"""Moves all model parameters and buffers to the CPU.
Returns:
@ -147,7 +138,7 @@ class DeviceDtypeModuleMixin(Module):
self.__update_properties(device=torch.device("cpu"))
return super().cpu()
def type(self, dst_type: Union[str, torch.dtype]) -> Self:
def type(self, dst_type: Union[str, torch.dtype]) -> Self: # type: ignore[valid-type]
"""Casts all parameters and buffers to :attr:`dst_type`.
Arguments:
@ -159,7 +150,7 @@ class DeviceDtypeModuleMixin(Module):
self.__update_properties(dtype=dst_type)
return super().type(dst_type=dst_type)
def float(self) -> Self:
def float(self) -> Self: # type: ignore[valid-type]
"""Casts all floating point parameters and buffers to ``float`` datatype.
Returns:
@ -168,7 +159,7 @@ class DeviceDtypeModuleMixin(Module):
self.__update_properties(dtype=torch.float)
return super().float()
def double(self) -> Self:
def double(self) -> Self: # type: ignore[valid-type]
"""Casts all floating point parameters and buffers to ``double`` datatype.
Returns:
@ -177,7 +168,7 @@ class DeviceDtypeModuleMixin(Module):
self.__update_properties(dtype=torch.double)
return super().double()
def half(self) -> Self:
def half(self) -> Self: # type: ignore[valid-type]
"""Casts all floating point parameters and buffers to ``half`` datatype.
Returns: