device property (#1791)

* device property

* add/copy properties

* inherit

* rename

* Apply suggestions from code review

Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>

* dtype

* prop

* pt api

Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
Jirka Borovec 2020-05-13 05:18:39 +02:00 committed by GitHub
parent 8978794730
commit 10ce1c0256
6 changed files with 170 additions and 12 deletions
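In short, this commit stops exposing `device` and `dtype` as plain attributes on `LightningModule` and instead tracks them in private `_device`/`_dtype` fields behind read-only properties, updated whenever the module is moved or cast. Below is a minimal sketch of the pattern (a simplified, hypothetical `Tracked` class showing only the `device` side; the real implementation is the `DeviceDtypeModuleMixin` in the diff that follows):

    import torch

    class Tracked(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # private field, updated only via .to()/.cuda()/.cpu()
            self._device = torch.device('cpu')

        @property
        def device(self) -> torch.device:
            return self._device

        @device.setter
        def device(self, new_device):
            # direct assignment is forbidden so the tracked value cannot drift
            raise RuntimeError('Cannot set the device explicitly. Please use module.to(new_device).')

        def to(self, *args, **kwargs) -> torch.nn.Module:
            # _parse_to normalizes the overloaded .to() signatures; its first
            # element is the target device (None if only a dtype was given)
            device = torch._C._nn._parse_to(*args, **kwargs)[0]
            if device is not None:
                self._device = device
            return super().to(*args, **kwargs)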

pytorch_lightning/core/lightning.py

@@ -16,6 +16,7 @@ from pytorch_lightning import _logger as log
 from pytorch_lightning.core.grads import GradInformation
 from pytorch_lightning.core.hooks import ModelHooks
 from pytorch_lightning.core.memory import ModelSummary
+from pytorch_lightning.core.properties import DeviceDtypeModuleMixin
 from pytorch_lightning.core.saving import ModelIO, load_hparams_from_tags_csv
 from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -29,14 +30,11 @@ else:
     XLA_AVAILABLE = True


-class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
+class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, ModelHooks):

     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)

-        #: Current dtype
-        self.dtype = torch.FloatTensor
-
         self.exp_save_path = None

         #: The current epoch
@@ -72,8 +70,10 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
         self.hparams = None

+        #: Current dtype
+        self._dtype = torch.FloatTensor
+
         #: device reference
-        self.device = None
+        self._device = torch.device('cpu')

     def print(self, *args, **kwargs) -> None:
         r"""

pytorch_lightning/core/properties.py (new file)

@@ -0,0 +1,156 @@
from typing import Union, Optional

import torch


class DeviceDtypeModuleMixin(torch.nn.Module):
    _device: ...
    _dtype: Union[str, torch.dtype]

    @property
    def dtype(self) -> Union[str, torch.dtype]:
        return self._dtype

    @dtype.setter
    def dtype(self, new_dtype: Union[str, torch.dtype]):
        # necessary to avoid infinite recursion
        raise RuntimeError('Cannot set the dtype explicitly. Please use module.to(new_dtype).')

    @property
    def device(self) -> Union[str, torch.device]:
        return self._device

    @device.setter
    def device(self, new_device: Union[str, torch.device]):
        # necessary to avoid infinite recursion
        raise RuntimeError('Cannot set the device explicitly. Please use module.to(new_device).')

    def to(self, *args, **kwargs) -> torch.nn.Module:
        """Moves and/or casts the parameters and buffers.

        This can be called as
        .. function:: to(device=None, dtype=None, non_blocking=False)
        .. function:: to(dtype, non_blocking=False)
        .. function:: to(tensor, non_blocking=False)

        Its signature is similar to :meth:`torch.Tensor.to`, but only accepts
        floating point :attr:`dtype` s. In addition, this method will
        only cast the floating point parameters and buffers to :attr:`dtype`
        (if given). The integral parameters and buffers will be moved to
        :attr:`device`, if that is given, but with dtypes unchanged. When
        :attr:`non_blocking` is set, it tries to convert/move asynchronously
        with respect to the host if possible, e.g., moving CPU Tensors with
        pinned memory to CUDA devices.

        See below for examples.

        Note:
            This method modifies the module in-place.

        Args:
            device: the desired device of the parameters
                and buffers in this module
            dtype: the desired floating point type of
                the floating point parameters and buffers in this module
            tensor: Tensor whose dtype and device are the desired
                dtype and device for all parameters and buffers in this module

        Returns:
            Module: self

        Example::
            >>> class ExampleModule(DeviceDtypeModuleMixin):
            ...     def __init__(self, weight: torch.Tensor):
            ...         super().__init__()
            ...         self.register_buffer('weight', weight)
            >>> _ = torch.manual_seed(0)
            >>> module = ExampleModule(torch.rand(3, 4))
            >>> module.weight #doctest: +ELLIPSIS
            tensor([[...]])
            >>> module.to(torch.double)
            ExampleModule()
            >>> module.weight #doctest: +ELLIPSIS
            tensor([[...]], dtype=torch.float64)
            >>> cpu = torch.device('cpu')
            >>> module.to(cpu, dtype=torch.half, non_blocking=True)
            ExampleModule()
            >>> module.weight #doctest: +ELLIPSIS
            tensor([[...]], dtype=torch.float16)
            >>> module.to(cpu)
            ExampleModule()
            >>> module.weight #doctest: +ELLIPSIS
            tensor([[...]], dtype=torch.float16)
        """
        # torch._C._nn._parse_to returns a different number of values in PyTorch 1.5,
        # so index the result instead of unpacking it
        out = torch._C._nn._parse_to(*args, **kwargs)
        device = out[0]
        dtype = out[1]
        if device is not None:
            self._device = device
        if dtype is not None:
            self._dtype = dtype
        return super().to(*args, **kwargs)

    def cuda(self, device: Optional[int] = None) -> torch.nn.Module:
        """Moves all model parameters and buffers to the GPU.

        This also makes associated parameters and buffers different objects. So
        it should be called before constructing the optimizer if the module will
        live on GPU while being optimized.

        Arguments:
            device: if specified, all parameters will be
                copied to that device

        Returns:
            Module: self
        """
        self._device = torch.device('cuda', index=device)
        return super().cuda(device=device)

    def cpu(self) -> torch.nn.Module:
        """Moves all model parameters and buffers to the CPU.

        Returns:
            Module: self
        """
        self._device = torch.device('cpu')
        return super().cpu()

    def type(self, dst_type: Union[str, torch.dtype]) -> torch.nn.Module:
        """Casts all parameters and buffers to :attr:`dst_type`.

        Arguments:
            dst_type (type or string): the desired type

        Returns:
            Module: self
        """
        self._dtype = dst_type
        return super().type(dst_type=dst_type)

    def float(self) -> torch.nn.Module:
        """Casts all floating point parameters and buffers to float datatype.

        Returns:
            Module: self
        """
        self._dtype = torch.float
        return super().float()

    def double(self) -> torch.nn.Module:
        """Casts all floating point parameters and buffers to ``double`` datatype.

        Returns:
            Module: self
        """
        self._dtype = torch.double
        return super().double()

    def half(self) -> torch.nn.Module:
        """Casts all floating point parameters and buffers to ``half`` datatype.

        Returns:
            Module: self
        """
        self._dtype = torch.half
        return super().half()

pytorch_lightning/trainer/distrib_data_parallel.py

@@ -344,7 +344,7 @@ class TrainerDDPMixin(ABC):

         # copy model to each gpu
         if self.on_gpu:
             self.root_gpu = process_idx
-            self.device = torch.device('cuda', self.root_gpu)
+            self._device = torch.device('cuda', self.root_gpu)
             torch.cuda.set_device(self.root_gpu)
             model.cuda(self.root_gpu)

pytorch_lightning/trainer/distrib_parts.py

@@ -432,7 +432,7 @@ class TrainerDPMixin(ABC):
             m.use_tpu = self.use_tpu
             m.tpu_local_core_rank = self.tpu_local_core_rank
             m.tpu_global_core_rank = self.tpu_global_core_rank
-            m.device = self.device
+            m._device = self._device

     def transfer_batch_to_tpu(self, batch):
         return self.__transfer_data_to_device(batch, device='tpu')
@@ -484,7 +484,7 @@ class TrainerDPMixin(ABC):

     def single_gpu_train(self, model):
         model.cuda(self.root_gpu)
-        self.device = torch.device('cuda', self.root_gpu)
+        self._device = torch.device('cuda', self.root_gpu)

         # CHOOSE OPTIMIZER
         # allow for lr schedulers as well
@@ -501,7 +501,7 @@ class TrainerDPMixin(ABC):
     def tpu_train(self, tpu_core_idx, model):
         # put model on tpu
         model.to(xm.xla_device())
-        self.device = xm.xla_device()
+        self._device = xm.xla_device()

         # get the appropriate tpu ranks
         self.tpu_local_core_rank = xm.get_local_ordinal()
@@ -539,7 +539,7 @@ class TrainerDPMixin(ABC):
         self.optimizers, self.lr_schedulers, self.optimizer_frequencies = self.init_optimizers(model)

         model.cuda(self.root_gpu)
-        self.device = torch.device('cuda', self.root_gpu)
+        self._device = torch.device('cuda', self.root_gpu)

         # hack forward to do autocast for the user
         model_autocast_original_forward = model.forward
@@ -579,7 +579,7 @@ class TrainerDPMixin(ABC):
         assert self.root_gpu == hvd.local_rank()
         torch.cuda.set_device(self.root_gpu)
         model.cuda(self.root_gpu)
-        self.device = torch.device('cuda', self.root_gpu)
+        self._device = torch.device('cuda', self.root_gpu)

         # avoid duplicating progress bar
         if hvd.rank() != 0 and self.progress_bar_callback is not None:

pytorch_lightning/trainer/trainer.py

@@ -473,7 +473,7 @@ class Trainer(
         # distributed backend choice
         self.distributed_backend = distributed_backend
         self.set_distributed_mode(distributed_backend)
-        self.device = torch.device('cpu')
+        self._device = torch.device('cpu')

         # override dist backend when using tpus
         if self.on_tpu:

tests/base/model_template.py

@@ -34,6 +34,8 @@ class EvalModelTemplate(
 ):
     """
     This template houses all combinations of model configurations we want to test
+
+    >>> model = EvalModelTemplate()
     """

     def __init__(self, hparams: object = None) -> object:
         """Pass in parsed HyperOptArgumentParser to the model."""