From 49c579f1f08427d3f6c2f38ee6b6d55b8dc2c1c1 Mon Sep 17 00:00:00 2001
From: Kunal Mundada <53429438+AlKun25@users.noreply.github.com>
Date: Fri, 5 Mar 2021 01:51:53 +0530
Subject: [PATCH] docstring changes in accelerators (#6327)

* docstring changes in accelerators

* docstrings moved

* whitespaces removed

* PEP8 correction[1]
---
 pytorch_lightning/accelerators/cpu.py | 5 +++++
 pytorch_lightning/accelerators/gpu.py | 5 +++++
 pytorch_lightning/accelerators/tpu.py | 9 ++++++++-
 3 files changed, 18 insertions(+), 1 deletion(-)

diff --git a/pytorch_lightning/accelerators/cpu.py b/pytorch_lightning/accelerators/cpu.py
index bc0551aefa..f428951b16 100644
--- a/pytorch_lightning/accelerators/cpu.py
+++ b/pytorch_lightning/accelerators/cpu.py
@@ -12,6 +12,11 @@ if TYPE_CHECKING:
 class CPUAccelerator(Accelerator):
 
     def setup(self, trainer: 'Trainer', model: 'LightningModule') -> None:
+        """
+        Raises:
+            MisconfigurationException:
+                If AMP is used with CPU, or if the selected device is not CPU.
+        """
         if isinstance(self.precision_plugin, MixedPrecisionPlugin):
             raise MisconfigurationException("amp + cpu is not supported. Please use a GPU option")
 
diff --git a/pytorch_lightning/accelerators/gpu.py b/pytorch_lightning/accelerators/gpu.py
index 785a0cc8dd..dd45e592bd 100644
--- a/pytorch_lightning/accelerators/gpu.py
+++ b/pytorch_lightning/accelerators/gpu.py
@@ -17,6 +17,11 @@ _log = logging.getLogger(__name__)
 class GPUAccelerator(Accelerator):
 
     def setup(self, trainer: 'Trainer', model: 'LightningModule') -> None:
+        """
+        Raises:
+            MisconfigurationException:
+                If the selected device is not GPU.
+        """
         if "cuda" not in str(self.root_device):
             raise MisconfigurationException(f"Device should be GPU, got {self.root_device} instead")
         self.set_nvidia_flags()
diff --git a/pytorch_lightning/accelerators/tpu.py b/pytorch_lightning/accelerators/tpu.py
index c36f7287f3..57e65a62f6 100644
--- a/pytorch_lightning/accelerators/tpu.py
+++ b/pytorch_lightning/accelerators/tpu.py
@@ -21,6 +21,11 @@ if TYPE_CHECKING:
 class TPUAccelerator(Accelerator):
 
     def setup(self, trainer: 'Trainer', model: 'LightningModule') -> None:
+        """
+        Raises:
+            MisconfigurationException:
+                If AMP is used with TPU, or if TPUs are not using a single TPU core or TPU spawn training.
+        """
         if isinstance(self.precision_plugin, MixedPrecisionPlugin):
             raise MisconfigurationException(
                 "amp + tpu is not supported. "
@@ -31,7 +36,9 @@ class TPUAccelerator(Accelerator):
             raise MisconfigurationException("TPUs only support a single tpu core or tpu spawn training.")
         return super().setup(trainer, model)
 
-    def run_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs):
+    def run_optimizer_step(
+        self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs: Any
+    ) -> None:
         xm.optimizer_step(optimizer, barrier=False, optimizer_args={'closure': lambda_closure, **kwargs})
 
     def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor: