docstring changes in accelerators (#6327)

* docstring changes in accelerators

* docstrings moved

* whitespaces removed

* PEP8 correction[1]
Author: Kunal Mundada, 2021-03-05 01:51:53 +05:30 (committed by GitHub)
parent 7acbd65bcb
commit 49c579f1f0
3 changed files with 18 additions and 1 deletion


@@ -12,6 +12,11 @@ if TYPE_CHECKING:
 class CPUAccelerator(Accelerator):
 
     def setup(self, trainer: 'Trainer', model: 'LightningModule') -> None:
+        """
+        Raises:
+            MisconfigurationException:
+                If AMP is used with CPU, or if the selected device is not CPU.
+        """
         if isinstance(self.precision_plugin, MixedPrecisionPlugin):
             raise MisconfigurationException("amp + cpu is not supported. Please use a GPU option")
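The added sections follow the Google-style "Raises:" docstring convention (the format rendered by Sphinx's napoleon extension). For reference only, here is a minimal, self-contained sketch of the same pattern on a hypothetical helper function that is not part of this commit:

    def select_accelerator(name: str) -> str:
        """Return a normalized accelerator name.

        Raises:
            ValueError:
                If name is not one of "cpu", "gpu" or "tpu".
        """
        # Same shape as the docstrings added above: document the exception,
        # then guard and raise with a clear message.
        if name not in ("cpu", "gpu", "tpu"):
            raise ValueError(f"Unsupported accelerator: {name}")
        return name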


@@ -17,6 +17,11 @@ _log = logging.getLogger(__name__)
 class GPUAccelerator(Accelerator):
 
     def setup(self, trainer: 'Trainer', model: 'LightningModule') -> None:
+        """
+        Raises:
+            MisconfigurationException:
+                If the selected device is not GPU.
+        """
         if "cuda" not in str(self.root_device):
             raise MisconfigurationException(f"Device should be GPU, got {self.root_device} instead")
         self.set_nvidia_flags()
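The condition documented here is a plain substring check on the accelerator's root device. A standalone sketch of that check on bare torch.device objects, with no Lightning internals (the helper name is made up for illustration):

    import torch

    def is_cuda_device(root_device: torch.device) -> bool:
        # GPUAccelerator.setup() accepts devices whose string form contains
        # "cuda", e.g. str(torch.device("cuda", 0)) == "cuda:0".
        return "cuda" in str(root_device)

    print(is_cuda_device(torch.device("cpu")))      # False -> setup() would raise
    print(is_cuda_device(torch.device("cuda", 0)))  # True  -> passes the check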


@@ -21,6 +21,11 @@ if TYPE_CHECKING:
 class TPUAccelerator(Accelerator):
 
     def setup(self, trainer: 'Trainer', model: 'LightningModule') -> None:
+        """
+        Raises:
+            MisconfigurationException:
+                If AMP is used with TPU, or if TPUs are not using a single TPU core or TPU spawn training.
+        """
         if isinstance(self.precision_plugin, MixedPrecisionPlugin):
             raise MisconfigurationException(
                 "amp + tpu is not supported. "
@@ -31,7 +36,9 @@ class TPUAccelerator(Accelerator):
             raise MisconfigurationException("TPUs only support a single tpu core or tpu spawn training.")
         return super().setup(trainer, model)
 
-    def run_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs):
+    def run_optimizer_step(
+        self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs: Any
+    ) -> None:
         xm.optimizer_step(optimizer, barrier=False, optimizer_args={'closure': lambda_closure, **kwargs})
 
     def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor:
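Beyond the docstring, the second hunk only re-types run_optimizer_step (**kwargs: Any, -> None) without changing behaviour: xm.optimizer_step still forwards its optimizer_args to optimizer.step(). The lambda_closure it receives plays the same role as the closure argument of a regular PyTorch optimizer; a rough CPU-only sketch of that pattern, with illustrative names and no torch_xla dependency:

    import torch

    model = torch.nn.Linear(4, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    inputs, target = torch.randn(8, 4), torch.randn(8, 1)

    def lambda_closure() -> torch.Tensor:
        # Re-evaluates the loss and gradients; the optimizer invokes this inside step().
        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(model(inputs), target)
        loss.backward()
        return loss

    # On TPU, xm.optimizer_step(optimizer, optimizer_args={'closure': ...}) passes
    # the same closure through to optimizer.step(); this is the plain-torch analogue.
    optimizer.step(closure=lambda_closure)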