Fix mypy errors attributed to `pytorch_lightning.core.mixins.device_dtype_mixin` (#13704)

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

parent 4c35867b61
commit 85b0356ea2
@@ -52,7 +52,6 @@ module = [
     "pytorch_lightning.callbacks.stochastic_weight_avg",
     "pytorch_lightning.core.datamodule",
     "pytorch_lightning.core.decorators",
-    "pytorch_lightning.core.mixins.device_dtype_mixin",
     "pytorch_lightning.core.module",
     "pytorch_lightning.core.saving",
     "pytorch_lightning.demos.boring_classes",
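Dropping `pytorch_lightning.core.mixins.device_dtype_mixin` from this per-module override list means mypy now checks the module at the project's regular strictness; the rest of the diff makes it pass. A quick way to verify locally, as a sketch using mypy's programmatic API (the file path is an assumption about the repository layout, and mypy must be installed):

    # Sketch: confirm the module now passes mypy via its programmatic API.
    from mypy import api

    stdout, stderr, exit_code = api.run(["pytorch_lightning/core/mixins/device_dtype_mixin.py"])
    print(stdout)
    assert exit_code == 0, "module should type-check cleanly after this change"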
@@ -16,16 +16,7 @@ from typing import Any, Optional, Union

 import torch
 from torch.nn import Module
-
-try:
-    from typing_extensions import Self
-except ImportError:
-    # workaround for Python 3.7.
-    # see https://www.python.org/dev/peps/pep-0673/
-    from typing import TypeVar
-
-    Self = TypeVar("TDeviceDtypeModuleMixin", bound="DeviceDtypeModuleMixin")
+from typing_extensions import Self

 import pytorch_lightning as pl
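The `try`/`except` fallback existed only for Python 3.7 (see the PEP 673 link above), and its TypeVar was misnamed `TDeviceDtypeModuleMixin` while being bound to the name `Self`, which itself confused mypy. Since `typing_extensions` backports `Self` on every supported interpreter, the unconditional import suffices. For illustration, a minimal sketch of the PEP 673 pattern (class names here are hypothetical, not from this commit):

    # Sketch of the PEP 673 ``Self`` pattern: methods annotated with Self
    # keep the *subclass* type for static checkers.
    from typing_extensions import Self


    class Base:
        def configure(self) -> Self:
            return self


    class Child(Base):
        pass


    child = Child().configure()  # inferred as Child, not Base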
@@ -57,7 +48,7 @@ class DeviceDtypeModuleMixin(Module):

         return device

-    def to(self, *args: Any, **kwargs: Any) -> Self:
+    def to(self, *args: Any, **kwargs: Any) -> Self:  # type: ignore[valid-type]
         """Moves and/or casts the parameters and buffers.

         This can be called as
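The mypy release pinned by the project at this point does not yet understand PEP 673 `Self`, so it flags each such annotation with a `valid-type` error; the error-code-scoped ignore silences exactly that diagnostic while keeping the precise annotation for readers and for newer checkers. A minimal reproduction (the error text is approximate and depends on the mypy version):

    # Sketch: on mypy releases predating PEP 673 support, a Self return
    # annotation triggers a [valid-type] error on the def line.
    from typing_extensions import Self


    class Mixin:
        def clone(self) -> Self:  # type: ignore[valid-type]
            # pre-PEP-673 mypy: 'Variable "Self" is not valid as a type'
            return self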
@@ -121,7 +112,7 @@ class DeviceDtypeModuleMixin(Module):
         self.__update_properties(device=out[0], dtype=out[1])
         return super().to(*args, **kwargs)

-    def cuda(self, device: Optional[Union[torch.device, int]] = None) -> Self:
+    def cuda(self, device: Optional[Union[torch.device, int]] = None) -> Self:  # type: ignore[valid-type]
         """Moves all model parameters and buffers to the GPU. This also makes associated parameters and buffers
         different objects. So it should be called before constructing optimizer if the module will live on GPU
         while being optimized.
@@ -134,11 +125,11 @@ class DeviceDtypeModuleMixin(Module):
             Module: self
         """
         if device is None or isinstance(device, int):
-            device = torch.device("cuda", index=device)
+            device = torch.device("cuda", index=(device or 0))
         self.__update_properties(device=device)
         return super().cuda(device=device)

-    def cpu(self) -> Self:
+    def cpu(self) -> Self:  # type: ignore[valid-type]
         """Moves all model parameters and buffers to the CPU.

         Returns:
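The `index=(device or 0)` line is the one change in this commit that goes beyond annotations: `index=device` could forward `None`, which the `torch.device` stubs reject since the index must be an `int`. Because `device or 0` maps both `None` and `0` to `0`, only the `None` case actually changes behavior, now pinning the default to GPU 0 explicitly. A small sketch of the normalization (constructing `torch.device` objects does not require a GPU):

    import torch

    # index=(device or 0): None and 0 both normalize to cuda:0;
    # other integer indices pass through unchanged.
    for device in (None, 0, 1):
        print(device, "->", torch.device("cuda", index=(device or 0)))
    # None -> cuda:0, 0 -> cuda:0, 1 -> cuda:1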
@@ -147,7 +138,7 @@ class DeviceDtypeModuleMixin(Module):
         self.__update_properties(device=torch.device("cpu"))
         return super().cpu()

-    def type(self, dst_type: Union[str, torch.dtype]) -> Self:
+    def type(self, dst_type: Union[str, torch.dtype]) -> Self:  # type: ignore[valid-type]
         """Casts all parameters and buffers to :attr:`dst_type`.

         Arguments:
@@ -159,7 +150,7 @@ class DeviceDtypeModuleMixin(Module):
         self.__update_properties(dtype=dst_type)
         return super().type(dst_type=dst_type)

-    def float(self) -> Self:
+    def float(self) -> Self:  # type: ignore[valid-type]
         """Casts all floating point parameters and buffers to ``float`` datatype.

         Returns:
@@ -168,7 +159,7 @@ class DeviceDtypeModuleMixin(Module):
         self.__update_properties(dtype=torch.float)
         return super().float()

-    def double(self) -> Self:
+    def double(self) -> Self:  # type: ignore[valid-type]
         """Casts all floating point parameters and buffers to ``double`` datatype.

         Returns:
@@ -177,7 +168,7 @@ class DeviceDtypeModuleMixin(Module):
         self.__update_properties(dtype=torch.double)
         return super().double()

-    def half(self) -> Self:
+    def half(self) -> Self:  # type: ignore[valid-type]
         """Casts all floating point parameters and buffers to ``half`` datatype.

         Returns:
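The same one-line signature change repeats for `cpu`, `type`, `float`, `double`, and `half`. The payoff of the `Self` annotations is that, under type checkers that support PEP 673, subclasses of the mixin keep their own type through these chained calls. A sketch with a hypothetical subclass (`MyModel` is illustrative, not from this commit):

    # Sketch: Self-typed returns preserve the subclass type through
    # chained device/dtype calls for static analysis.
    import pytorch_lightning as pl


    class MyModel(pl.LightningModule):
        def custom_step(self) -> None:
            print("still statically known as MyModel")


    model = MyModel().double().cpu()  # typed as MyModel, not torch.nn.Module
    model.custom_step()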