docstring changes in accelerators (#6327)

* docstring changes in accelerators
* docstrings moved
* whitespaces removed
* PEP8 correction
parent 7acbd65bcb
commit 49c579f1f0
@@ -12,6 +12,11 @@ if TYPE_CHECKING:

 class CPUAccelerator(Accelerator):

     def setup(self, trainer: 'Trainer', model: 'LightningModule') -> None:
+        """
+        Raises:
+            MisconfigurationException:
+                If AMP is used with CPU, or if the selected device is not CPU.
+        """
         if isinstance(self.precision_plugin, MixedPrecisionPlugin):
             raise MisconfigurationException("amp + cpu is not supported. Please use a GPU option")
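For context, the docstrings added in this hunk follow the Google style ("Raises:" sections) that Sphinx's napoleon extension renders as a dedicated "Raises" field. A minimal, self-contained sketch of the same convention on an invented function (nothing below is part of the diff itself):

def select_device(name: str) -> None:
    """Select a compute device by name.

    Raises:
        ValueError:
            If ``name`` is not one of the supported device types.
    """
    # The "Raises:" section above documents exactly this failure mode.
    if name not in ("cpu", "cuda", "tpu"):
        raise ValueError(f"Unknown device: {name}")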
@@ -17,6 +17,11 @@ _log = logging.getLogger(__name__)

 class GPUAccelerator(Accelerator):

     def setup(self, trainer: 'Trainer', model: 'LightningModule') -> None:
+        """
+        Raises:
+            MisconfigurationException:
+                If the selected device is not GPU.
+        """
         if "cuda" not in str(self.root_device):
             raise MisconfigurationException(f"Device should be GPU, got {self.root_device} instead")
         self.set_nvidia_flags()
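The GPU check above relies on the string form of torch.device starting with the device type, so the substring test is true exactly for CUDA devices. A standalone sketch in plain PyTorch (no Lightning required, and constructing a torch.device does not need an actual GPU):

import torch

# str(torch.device(...)) begins with the device type, which is why
# '"cuda" not in str(self.root_device)' identifies non-GPU devices.
for dev in (torch.device("cpu"), torch.device("cuda", 0)):
    print(dev, "->", "cuda" in str(dev))
# cpu -> False
# cuda:0 -> True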
@@ -21,6 +21,11 @@ if TYPE_CHECKING:
 class TPUAccelerator(Accelerator):

     def setup(self, trainer: 'Trainer', model: 'LightningModule') -> None:
+        """
+        Raises:
+            MisconfigurationException:
+                If AMP is used with TPU, or if TPUs are not using a single TPU core or TPU spawn training.
+        """
         if isinstance(self.precision_plugin, MixedPrecisionPlugin):
             raise MisconfigurationException(
                 "amp + tpu is not supported. "
@@ -31,7 +36,9 @@ class TPUAccelerator(Accelerator):
             raise MisconfigurationException("TPUs only support a single tpu core or tpu spawn training.")
         return super().setup(trainer, model)

-    def run_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs):
+    def run_optimizer_step(
+        self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs: Any
+    ) -> None:
         xm.optimizer_step(optimizer, barrier=False, optimizer_args={'closure': lambda_closure, **kwargs})

     def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor:
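The signature change in this hunk only adds type annotations (**kwargs: Any, -> None); behavior is unchanged. The lambda_closure it forwards to xm.optimizer_step mirrors the standard closure argument of torch.optim optimizers. A minimal CPU-only sketch of that closure pattern (illustrative only; the model, data, and names here are invented, not part of the diff):

import torch

model = torch.nn.Linear(2, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
x, y = torch.randn(8, 2), torch.randn(8, 1)

def closure() -> torch.Tensor:
    # Recomputes the loss and gradients; the optimizer calls this itself.
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    return loss

# On TPU, xm.optimizer_step(optimizer, optimizer_args={'closure': closure})
# plays the role of this plain step(closure) call.
optimizer.step(closure)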