From d47173bb72cf117bdea3d2eea4f3e0d3b82f6614 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Mon, 31 May 2021 09:54:28 +0200 Subject: [PATCH] Use typing forward references (#7770) * Use typing forward references * Update pytorch_lightning/core/lightning.py --- pytorch_lightning/core/lightning.py | 62 +++++++++---------- pytorch_lightning/overrides/base.py | 6 +- pytorch_lightning/overrides/data_parallel.py | 4 +- pytorch_lightning/overrides/distributed.py | 4 +- .../plugins/precision/apex_amp.py | 2 +- .../plugins/training_type/deepspeed.py | 6 +- .../plugins/training_type/parallel.py | 7 +-- tests/accelerators/test_multi_nodes_gpu.py | 2 +- 8 files changed, 44 insertions(+), 49 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index cf6e25c54f..3168d36957 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -324,42 +324,38 @@ class LightningModule( ' `https://github.com/PyTorchLightning/pytorch-lightning/discussions`' ) - if self._results is not None: - # TODO: if logged twice fail with crash + # set the default depending on the fx_name + on_step = self.__auto_choose_log_on_step(on_step) + on_epoch = self.__auto_choose_log_on_epoch(on_epoch) - # set the default depending on the fx_name - on_step = self.__auto_choose_log_on_step(on_step) - on_epoch = self.__auto_choose_log_on_epoch(on_epoch) + assert self._current_fx_name is not None + self.trainer.logger_connector.check_logging(self._current_fx_name, on_step=on_step, on_epoch=on_epoch) - assert self._current_fx_name is not None - self.trainer.logger_connector.check_logging(self._current_fx_name, on_step=on_step, on_epoch=on_epoch) + # make sure user doesn't introduce logic for multi-dataloaders + if "/dataloader_idx_" in name: + raise MisconfigurationException(f"Logged key: {name} should not contain information about dataloader_idx.") - # make sure user doesn't introduce logic for multi-dataloaders - if "/dataloader_idx_" in name: - raise MisconfigurationException( - f"Logged key: {name} should not contain information about dataloader_idx." - ) + value = self.__sync( + value, + sync_fn=self.trainer.training_type_plugin.reduce, + sync_dist=sync_dist, + sync_dist_op=sync_dist_op, + sync_dist_group=sync_dist_group, + device=self.device, + ) - value = self.__sync( - value, - sync_fn=self.trainer.training_type_plugin.reduce, - sync_dist=sync_dist, - sync_dist_op=sync_dist_op, - sync_dist_group=sync_dist_group, - device=self.device, - ) - - self._results.log( - name, - value, - prog_bar=prog_bar, - logger=logger, - on_step=on_step, - on_epoch=on_epoch, - reduce_fx=reduce_fx, - enable_graph=enable_graph, - dataloader_idx=(self._current_dataloader_idx if add_dataloader_idx else None), - ) + assert self._results is not None + self._results.log( + name, + value, + prog_bar=prog_bar, + logger=logger, + on_step=on_step, + on_epoch=on_epoch, + reduce_fx=reduce_fx, + enable_graph=enable_graph, + dataloader_idx=(self._current_dataloader_idx if add_dataloader_idx else None), + ) def log_dict( self, @@ -378,7 +374,7 @@ class LightningModule( add_dataloader_idx: bool = True, ) -> None: """ - Log a dictonary of values at once + Log a dictionary of values at once Example:: diff --git a/pytorch_lightning/overrides/base.py b/pytorch_lightning/overrides/base.py index 88e8ed6375..e086779bec 100644 --- a/pytorch_lightning/overrides/base.py +++ b/pytorch_lightning/overrides/base.py @@ -15,13 +15,13 @@ import torch from torch.nn import DataParallel from torch.nn.parallel import DistributedDataParallel -from pytorch_lightning.core.lightning import LightningModule +import pytorch_lightning as pl from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin class _LightningModuleWrapperBase(DeviceDtypeModuleMixin, torch.nn.Module): - def __init__(self, pl_module: LightningModule): + def __init__(self, pl_module: 'pl.LightningModule') -> None: """ Wraps the user's LightningModule and redirects the forward call to the appropriate method, either ``training_step``, ``validation_step`` or ``test_step``. @@ -66,7 +66,7 @@ class _LightningModuleWrapperBase(DeviceDtypeModuleMixin, torch.nn.Module): pass -def unwrap_lightning_module(wrapped_model) -> LightningModule: +def unwrap_lightning_module(wrapped_model) -> 'pl.LightningModule': model = wrapped_model if isinstance(model, (DistributedDataParallel, DataParallel)): model = model.module diff --git a/pytorch_lightning/overrides/data_parallel.py b/pytorch_lightning/overrides/data_parallel.py index 3d6e527ef9..57919db6ab 100644 --- a/pytorch_lightning/overrides/data_parallel.py +++ b/pytorch_lightning/overrides/data_parallel.py @@ -17,7 +17,7 @@ from typing import Any import torch -from pytorch_lightning.core.lightning import LightningModule +import pytorch_lightning as pl from pytorch_lightning.overrides.base import _LightningModuleWrapperBase from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.apply_func import apply_to_collection @@ -53,7 +53,7 @@ class LightningParallelModule(_LightningModuleWrapperBase): """ - def __init__(self, pl_module: LightningModule): + def __init__(self, pl_module: 'pl.LightningModule') -> None: super().__init__(pl_module) _ignore_scalar_return_in_dp() diff --git a/pytorch_lightning/overrides/distributed.py b/pytorch_lightning/overrides/distributed.py index 559e1161ce..d4b1e6ed22 100644 --- a/pytorch_lightning/overrides/distributed.py +++ b/pytorch_lightning/overrides/distributed.py @@ -18,13 +18,13 @@ import torch from torch.nn.parallel import DistributedDataParallel from torch.utils.data import BatchSampler, DistributedSampler, Sampler -from pytorch_lightning.core.lightning import LightningModule +import pytorch_lightning as pl from pytorch_lightning.overrides.base import _LightningModuleWrapperBase class LightningDistributedModule(_LightningModuleWrapperBase): - def __init__(self, pl_module: LightningModule): + def __init__(self, pl_module: 'pl.LightningModule') -> None: """ Wraps the user's LightningModule and redirects the forward call to the appropriate method, either ``training_step``, ``validation_step``, ``test_step`` or ``predict``. diff --git a/pytorch_lightning/plugins/precision/apex_amp.py b/pytorch_lightning/plugins/precision/apex_amp.py index 71c2119e73..aa3aad7689 100644 --- a/pytorch_lightning/plugins/precision/apex_amp.py +++ b/pytorch_lightning/plugins/precision/apex_amp.py @@ -39,7 +39,7 @@ class ApexMixedPrecisionPlugin(MixedPrecisionPlugin): def master_params(self, optimizer: Optimizer) -> _PARAMETERS: return amp.master_params(optimizer) - def dispatch(self, trainer: "pl.Trainer") -> None: + def dispatch(self, trainer: 'pl.Trainer') -> None: if not self._connected: accelerator = trainer.accelerator _, accelerator.optimizers = amp.initialize( diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index 8dd04aafa6..3481986f21 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -22,8 +22,8 @@ from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union import torch +import pytorch_lightning as pl from pytorch_lightning.callbacks import GradientAccumulationScheduler -from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.overrides.base import _LightningModuleWrapperBase from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.training_type.ddp import DDPPlugin @@ -51,7 +51,7 @@ def remove_module_hooks(model: torch.nn.Module) -> None: class LightningDeepSpeedModule(_LightningModuleWrapperBase): - def __init__(self, pl_module: LightningModule, precision: int): + def __init__(self, pl_module: 'pl.LightningModule', precision: int) -> None: super().__init__(pl_module) self.precision = precision @@ -378,7 +378,7 @@ class DeepSpeedPlugin(DDPPlugin): distributed_sampler_kwargs = dict(num_replicas=self.world_size, rank=self.global_rank) return distributed_sampler_kwargs - def init_optimizers(self, trainer, model: LightningModule) -> Tuple[List, List, List]: + def init_optimizers(self, trainer: 'pl.Trainer', model: 'pl.LightningModule') -> Tuple[List, List, List]: # Skip initializing optimizers here as DeepSpeed handles optimizers via config. # User may have specified config options instead in configure_optimizers, but this is handled # via `_initialize_deepspeed_train` diff --git a/pytorch_lightning/plugins/training_type/parallel.py b/pytorch_lightning/plugins/training_type/parallel.py index a8028e5be1..09e48a760e 100644 --- a/pytorch_lightning/plugins/training_type/parallel.py +++ b/pytorch_lightning/plugins/training_type/parallel.py @@ -19,7 +19,7 @@ from typing import Any, List, Optional import torch from torch.nn.parallel import DistributedDataParallel -from pytorch_lightning.core.lightning import LightningModule +import pytorch_lightning as pl from pytorch_lightning.overrides.base import unwrap_lightning_module from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin @@ -99,7 +99,7 @@ class ParallelPlugin(TrainingTypePlugin, ABC): return torch_backend @staticmethod - def configure_sync_batchnorm(model: LightningModule) -> LightningModule: + def configure_sync_batchnorm(model: 'pl.LightningModule') -> 'pl.LightningModule': """ Add global batchnorm for a model spread across multiple GPUs and nodes. @@ -112,8 +112,7 @@ class ParallelPlugin(TrainingTypePlugin, ABC): Return: LightningModule with batchnorm layers synchronized between process groups """ - model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) - return model + return torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) @contextmanager def block_backward_sync(self): diff --git a/tests/accelerators/test_multi_nodes_gpu.py b/tests/accelerators/test_multi_nodes_gpu.py index cae257666e..463307ead8 100644 --- a/tests/accelerators/test_multi_nodes_gpu.py +++ b/tests/accelerators/test_multi_nodes_gpu.py @@ -125,5 +125,5 @@ def test__validation_step__log(tmpdir): } # we don't want to enable val metrics during steps because it is not something that users should do - # on purpose DO NOT allow step_b... it's silly to monitor val step metrics + # on purpose DO NOT allow b_step... it's silly to monitor val step metrics assert set(trainer.callback_metrics) == {'a', 'a2', 'b', 'a_epoch', 'b_epoch', 'a_step'}