From 10ce1c0256e562b42cdedbc81beba626ca959f63 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Wed, 13 May 2020 05:18:39 +0200 Subject: [PATCH] device property (#1791) * device property * add/copy properties * inherit * rename * Apply suggestions from code review Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com> * dtype * prop * pt api Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com> --- pytorch_lightning/core/lightning.py | 10 +- pytorch_lightning/core/properties.py | 156 ++++++++++++++++++ .../trainer/distrib_data_parallel.py | 2 +- pytorch_lightning/trainer/distrib_parts.py | 10 +- pytorch_lightning/trainer/trainer.py | 2 +- tests/base/model_template.py | 2 + 6 files changed, 170 insertions(+), 12 deletions(-) create mode 100644 pytorch_lightning/core/properties.py diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 32dd13a779..fac41bda15 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -16,6 +16,7 @@ from pytorch_lightning import _logger as log from pytorch_lightning.core.grads import GradInformation from pytorch_lightning.core.hooks import ModelHooks from pytorch_lightning.core.memory import ModelSummary +from pytorch_lightning.core.properties import DeviceDtypeModuleMixin from pytorch_lightning.core.saving import ModelIO, load_hparams_from_tags_csv from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -29,14 +30,11 @@ else: XLA_AVAILABLE = True -class LightningModule(ABC, GradInformation, ModelIO, ModelHooks): +class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, ModelHooks): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - #: Current dtype - self.dtype = torch.FloatTensor - self.exp_save_path = None #: The current epoch @@ -72,8 +70,10 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks): self.hparams = None + #: Current dtype + self._dtype = torch.FloatTensor #: device reference - self.device = None + self._device = torch.device('cpu') def print(self, *args, **kwargs) -> None: r""" diff --git a/pytorch_lightning/core/properties.py b/pytorch_lightning/core/properties.py new file mode 100644 index 0000000000..eb3faf54fa --- /dev/null +++ b/pytorch_lightning/core/properties.py @@ -0,0 +1,156 @@ +from typing import Union, Optional + +import torch + + +class DeviceDtypeModuleMixin(torch.nn.Module): + _device: ... + _dtype: Union[str, torch.dtype] + + @property + def dtype(self) -> Union[str, torch.dtype]: + return self._dtype + + @dtype.setter + def dtype(self, new_dtype: Union[str, torch.dtype]): + # necessary to avoid infinite recursion + raise RuntimeError('Cannot set the dtype explicitly. Please use module.to(new_dtype).') + + @property + def device(self) -> Union[str, torch.device]: + return self._device + + @device.setter + def device(self, new_device: Union[str, torch.device]): + # Necessary to avoid infinite recursion + raise RuntimeError('Cannot set the device explicitly. Please use module.to(new_device).') + + def to(self, *args, **kwargs) -> torch.nn.Module: + """Moves and/or casts the parameters and buffers. + + This can be called as + .. function:: to(device=None, dtype=None, non_blocking=False) + .. function:: to(dtype, non_blocking=False) + .. function:: to(tensor, non_blocking=False) + Its signature is similar to :meth:`torch.Tensor.to`, but only accepts + floating point desired :attr:`dtype` s. In addition, this method will + only cast the floating point parameters and buffers to :attr:`dtype` + (if given). The integral parameters and buffers will be moved + :attr:`device`, if that is given, but with dtypes unchanged. When + :attr:`non_blocking` is set, it tries to convert/move asynchronously + with respect to the host if possible, e.g., moving CPU Tensors with + pinned memory to CUDA devices. + See below for examples. + + Note: + This method modifies the module in-place. + + Args: + device: the desired device of the parameters + and buffers in this module + dtype: the desired floating point type of + the floating point parameters and buffers in this module + tensor: Tensor whose dtype and device are the desired + dtype and device for all parameters and buffers in this module + + Returns: + Module: self + + Example:: + >>> class ExampleModule(DeviceDtypeModuleMixin): + ... def __init__(self, weight: torch.Tensor): + ... super().__init__() + ... self.register_buffer('weight', weight) + >>> _ = torch.manual_seed(0) + >>> module = ExampleModule(torch.rand(3, 4)) + >>> module.weight #doctest: +ELLIPSIS + tensor([[...]]) + >>> module.to(torch.double) + ExampleModule() + >>> module.weight #doctest: +ELLIPSIS + tensor([[...]], dtype=torch.float64) + >>> cpu = torch.device('cpu') + >>> module.to(cpu, dtype=torch.half, non_blocking=True) + ExampleModule() + >>> module.weight #doctest: +ELLIPSIS + tensor([[...]], dtype=torch.float16) + >>> module.to(cpu) + ExampleModule() + >>> module.weight #doctest: +ELLIPSIS + tensor([[...]], dtype=torch.float16) + """ + # there is diff nb vars in PT 1.5 + out = torch._C._nn._parse_to(*args, **kwargs) + device = out[0] + dtype = out[1] + if device is not None: + self._device = device + + if dtype is not None: + self._dtype = dtype + + return super().to(*args, **kwargs) + + def cuda(self, device: Optional[int] = None) -> torch.nn.Module: + """Moves all model parameters and buffers to the GPU. + This also makes associated parameters and buffers different objects. So + it should be called before constructing optimizer if the module will + live on GPU while being optimized. + + Arguments: + device: if specified, all parameters will be + copied to that device + + Returns: + Module: self + """ + + self._device = torch.device('cuda', index=device) + return super().cuda(device=device) + + def cpu(self) -> torch.nn.Module: + """Moves all model parameters and buffers to the CPU. + Returns: + Module: self + """ + self._device = torch.device('cpu') + return super().cpu() + + def type(self, dst_type: Union[str, torch.dtype]) -> torch.nn.Module: + """Casts all parameters and buffers to :attr:`dst_type`. + + Arguments: + dst_type (type or string): the desired type + + Returns: + Module: self + """ + self._dtype = dst_type + return super().type(dst_type=dst_type) + + def float(self) -> torch.nn.Module: + """Casts all floating point parameters and buffers to float datatype. + + Returns: + Module: self + """ + self._dtype = torch.float + return super().float() + + def double(self) -> torch.nn.Module: + """Casts all floating point parameters and buffers to ``double`` datatype. + + Returns: + Module: self + """ + self._dtype = torch.double + return super().double() + + def half(self) -> torch.nn.Module: + """Casts all floating point parameters and buffers to ``half`` datatype. + + Returns: + Module: self + """ + self._dtype = torch.half + return super().half() diff --git a/pytorch_lightning/trainer/distrib_data_parallel.py b/pytorch_lightning/trainer/distrib_data_parallel.py index bd97d5ca33..a7925ab4fd 100644 --- a/pytorch_lightning/trainer/distrib_data_parallel.py +++ b/pytorch_lightning/trainer/distrib_data_parallel.py @@ -344,7 +344,7 @@ class TrainerDDPMixin(ABC): # copy model to each gpu if self.on_gpu: self.root_gpu = process_idx - self.device = torch.device('cuda', self.root_gpu) + self._device = torch.device('cuda', self.root_gpu) torch.cuda.set_device(self.root_gpu) model.cuda(self.root_gpu) diff --git a/pytorch_lightning/trainer/distrib_parts.py b/pytorch_lightning/trainer/distrib_parts.py index 1bd235ceba..842496402c 100644 --- a/pytorch_lightning/trainer/distrib_parts.py +++ b/pytorch_lightning/trainer/distrib_parts.py @@ -432,7 +432,7 @@ class TrainerDPMixin(ABC): m.use_tpu = self.use_tpu m.tpu_local_core_rank = self.tpu_local_core_rank m.tpu_global_core_rank = self.tpu_global_core_rank - m.device = self.device + m._device = self._device def transfer_batch_to_tpu(self, batch): return self.__transfer_data_to_device(batch, device='tpu') @@ -484,7 +484,7 @@ class TrainerDPMixin(ABC): def single_gpu_train(self, model): model.cuda(self.root_gpu) - self.device = torch.device('cuda', self.root_gpu) + self._device = torch.device('cuda', self.root_gpu) # CHOOSE OPTIMIZER # allow for lr schedulers as well @@ -501,7 +501,7 @@ class TrainerDPMixin(ABC): def tpu_train(self, tpu_core_idx, model): # put model on tpu model.to(xm.xla_device()) - self.device = xm.xla_device() + self._device = xm.xla_device() # get the appropriate tpu ranks self.tpu_local_core_rank = xm.get_local_ordinal() @@ -539,7 +539,7 @@ class TrainerDPMixin(ABC): self.optimizers, self.lr_schedulers, self.optimizer_frequencies = self.init_optimizers(model) model.cuda(self.root_gpu) - self.device = torch.device('cuda', self.root_gpu) + self._device = torch.device('cuda', self.root_gpu) # hack forward to do autocast for the user model_autocast_original_forward = model.forward @@ -579,7 +579,7 @@ class TrainerDPMixin(ABC): assert self.root_gpu == hvd.local_rank() torch.cuda.set_device(self.root_gpu) model.cuda(self.root_gpu) - self.device = torch.device('cuda', self.root_gpu) + self._device = torch.device('cuda', self.root_gpu) # avoid duplicating progress bar if hvd.rank() != 0 and self.progress_bar_callback is not None: diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index d0b1aa27a5..9639eec8c6 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -473,7 +473,7 @@ class Trainer( # distributed backend choice self.distributed_backend = distributed_backend self.set_distributed_mode(distributed_backend) - self.device = torch.device('cpu') + self._device = torch.device('cpu') # override dist backend when using tpus if self.on_tpu: diff --git a/tests/base/model_template.py b/tests/base/model_template.py index 41e6edc5e3..d530fa4a97 100644 --- a/tests/base/model_template.py +++ b/tests/base/model_template.py @@ -34,6 +34,8 @@ class EvalModelTemplate( ): """ This template houses all combinations of model configurations we want to test + + >>> model = EvalModelTemplate() """ def __init__(self, hparams: object = None) -> object: """Pass in parsed HyperOptArgumentParser to the model."""