Use typing forward references (#7770)

* Use typing forward references

* Update pytorch_lightning/core/lightning.py
This commit is contained in:
Carlos Mocholí 2021-05-31 09:54:28 +02:00 committed by GitHub
parent a69beab499
commit d47173bb72
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 44 additions and 49 deletions

View File

@ -324,42 +324,38 @@ class LightningModule(
' `https://github.com/PyTorchLightning/pytorch-lightning/discussions`' ' `https://github.com/PyTorchLightning/pytorch-lightning/discussions`'
) )
if self._results is not None: # set the default depending on the fx_name
# TODO: if logged twice fail with crash on_step = self.__auto_choose_log_on_step(on_step)
on_epoch = self.__auto_choose_log_on_epoch(on_epoch)
# set the default depending on the fx_name assert self._current_fx_name is not None
on_step = self.__auto_choose_log_on_step(on_step) self.trainer.logger_connector.check_logging(self._current_fx_name, on_step=on_step, on_epoch=on_epoch)
on_epoch = self.__auto_choose_log_on_epoch(on_epoch)
assert self._current_fx_name is not None # make sure user doesn't introduce logic for multi-dataloaders
self.trainer.logger_connector.check_logging(self._current_fx_name, on_step=on_step, on_epoch=on_epoch) if "/dataloader_idx_" in name:
raise MisconfigurationException(f"Logged key: {name} should not contain information about dataloader_idx.")
# make sure user doesn't introduce logic for multi-dataloaders value = self.__sync(
if "/dataloader_idx_" in name: value,
raise MisconfigurationException( sync_fn=self.trainer.training_type_plugin.reduce,
f"Logged key: {name} should not contain information about dataloader_idx." sync_dist=sync_dist,
) sync_dist_op=sync_dist_op,
sync_dist_group=sync_dist_group,
device=self.device,
)
value = self.__sync( assert self._results is not None
value, self._results.log(
sync_fn=self.trainer.training_type_plugin.reduce, name,
sync_dist=sync_dist, value,
sync_dist_op=sync_dist_op, prog_bar=prog_bar,
sync_dist_group=sync_dist_group, logger=logger,
device=self.device, on_step=on_step,
) on_epoch=on_epoch,
reduce_fx=reduce_fx,
self._results.log( enable_graph=enable_graph,
name, dataloader_idx=(self._current_dataloader_idx if add_dataloader_idx else None),
value, )
prog_bar=prog_bar,
logger=logger,
on_step=on_step,
on_epoch=on_epoch,
reduce_fx=reduce_fx,
enable_graph=enable_graph,
dataloader_idx=(self._current_dataloader_idx if add_dataloader_idx else None),
)
def log_dict( def log_dict(
self, self,
@ -378,7 +374,7 @@ class LightningModule(
add_dataloader_idx: bool = True, add_dataloader_idx: bool = True,
) -> None: ) -> None:
""" """
Log a dictonary of values at once Log a dictionary of values at once
Example:: Example::

View File

@ -15,13 +15,13 @@ import torch
from torch.nn import DataParallel from torch.nn import DataParallel
from torch.nn.parallel import DistributedDataParallel from torch.nn.parallel import DistributedDataParallel
from pytorch_lightning.core.lightning import LightningModule import pytorch_lightning as pl
from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
class _LightningModuleWrapperBase(DeviceDtypeModuleMixin, torch.nn.Module): class _LightningModuleWrapperBase(DeviceDtypeModuleMixin, torch.nn.Module):
def __init__(self, pl_module: LightningModule): def __init__(self, pl_module: 'pl.LightningModule') -> None:
""" """
Wraps the user's LightningModule and redirects the forward call to the appropriate Wraps the user's LightningModule and redirects the forward call to the appropriate
method, either ``training_step``, ``validation_step`` or ``test_step``. method, either ``training_step``, ``validation_step`` or ``test_step``.
@ -66,7 +66,7 @@ class _LightningModuleWrapperBase(DeviceDtypeModuleMixin, torch.nn.Module):
pass pass
def unwrap_lightning_module(wrapped_model) -> LightningModule: def unwrap_lightning_module(wrapped_model) -> 'pl.LightningModule':
model = wrapped_model model = wrapped_model
if isinstance(model, (DistributedDataParallel, DataParallel)): if isinstance(model, (DistributedDataParallel, DataParallel)):
model = model.module model = model.module

View File

@ -17,7 +17,7 @@ from typing import Any
import torch import torch
from pytorch_lightning.core.lightning import LightningModule import pytorch_lightning as pl
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.apply_func import apply_to_collection
@ -53,7 +53,7 @@ class LightningParallelModule(_LightningModuleWrapperBase):
""" """
def __init__(self, pl_module: LightningModule): def __init__(self, pl_module: 'pl.LightningModule') -> None:
super().__init__(pl_module) super().__init__(pl_module)
_ignore_scalar_return_in_dp() _ignore_scalar_return_in_dp()

View File

@ -18,13 +18,13 @@ import torch
from torch.nn.parallel import DistributedDataParallel from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import BatchSampler, DistributedSampler, Sampler from torch.utils.data import BatchSampler, DistributedSampler, Sampler
from pytorch_lightning.core.lightning import LightningModule import pytorch_lightning as pl
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
class LightningDistributedModule(_LightningModuleWrapperBase): class LightningDistributedModule(_LightningModuleWrapperBase):
def __init__(self, pl_module: LightningModule): def __init__(self, pl_module: 'pl.LightningModule') -> None:
""" """
Wraps the user's LightningModule and redirects the forward call to the appropriate Wraps the user's LightningModule and redirects the forward call to the appropriate
method, either ``training_step``, ``validation_step``, ``test_step`` or ``predict``. method, either ``training_step``, ``validation_step``, ``test_step`` or ``predict``.

View File

@ -39,7 +39,7 @@ class ApexMixedPrecisionPlugin(MixedPrecisionPlugin):
def master_params(self, optimizer: Optimizer) -> _PARAMETERS: def master_params(self, optimizer: Optimizer) -> _PARAMETERS:
return amp.master_params(optimizer) return amp.master_params(optimizer)
def dispatch(self, trainer: "pl.Trainer") -> None: def dispatch(self, trainer: 'pl.Trainer') -> None:
if not self._connected: if not self._connected:
accelerator = trainer.accelerator accelerator = trainer.accelerator
_, accelerator.optimizers = amp.initialize( _, accelerator.optimizers = amp.initialize(

View File

@ -22,8 +22,8 @@ from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
import torch import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import GradientAccumulationScheduler from pytorch_lightning.callbacks import GradientAccumulationScheduler
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
@ -51,7 +51,7 @@ def remove_module_hooks(model: torch.nn.Module) -> None:
class LightningDeepSpeedModule(_LightningModuleWrapperBase): class LightningDeepSpeedModule(_LightningModuleWrapperBase):
def __init__(self, pl_module: LightningModule, precision: int): def __init__(self, pl_module: 'pl.LightningModule', precision: int) -> None:
super().__init__(pl_module) super().__init__(pl_module)
self.precision = precision self.precision = precision
@ -378,7 +378,7 @@ class DeepSpeedPlugin(DDPPlugin):
distributed_sampler_kwargs = dict(num_replicas=self.world_size, rank=self.global_rank) distributed_sampler_kwargs = dict(num_replicas=self.world_size, rank=self.global_rank)
return distributed_sampler_kwargs return distributed_sampler_kwargs
def init_optimizers(self, trainer, model: LightningModule) -> Tuple[List, List, List]: def init_optimizers(self, trainer: 'pl.Trainer', model: 'pl.LightningModule') -> Tuple[List, List, List]:
# Skip initializing optimizers here as DeepSpeed handles optimizers via config. # Skip initializing optimizers here as DeepSpeed handles optimizers via config.
# User may have specified config options instead in configure_optimizers, but this is handled # User may have specified config options instead in configure_optimizers, but this is handled
# via `_initialize_deepspeed_train` # via `_initialize_deepspeed_train`

View File

@ -19,7 +19,7 @@ from typing import Any, List, Optional
import torch import torch
from torch.nn.parallel import DistributedDataParallel from torch.nn.parallel import DistributedDataParallel
from pytorch_lightning.core.lightning import LightningModule import pytorch_lightning as pl
from pytorch_lightning.overrides.base import unwrap_lightning_module from pytorch_lightning.overrides.base import unwrap_lightning_module
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin
@ -99,7 +99,7 @@ class ParallelPlugin(TrainingTypePlugin, ABC):
return torch_backend return torch_backend
@staticmethod @staticmethod
def configure_sync_batchnorm(model: LightningModule) -> LightningModule: def configure_sync_batchnorm(model: 'pl.LightningModule') -> 'pl.LightningModule':
""" """
Add global batchnorm for a model spread across multiple GPUs and nodes. Add global batchnorm for a model spread across multiple GPUs and nodes.
@ -112,8 +112,7 @@ class ParallelPlugin(TrainingTypePlugin, ABC):
Return: Return:
LightningModule with batchnorm layers synchronized between process groups LightningModule with batchnorm layers synchronized between process groups
""" """
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) return torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
return model
@contextmanager @contextmanager
def block_backward_sync(self): def block_backward_sync(self):

View File

@ -125,5 +125,5 @@ def test__validation_step__log(tmpdir):
} }
# we don't want to enable val metrics during steps because it is not something that users should do # we don't want to enable val metrics during steps because it is not something that users should do
# on purpose DO NOT allow step_b... it's silly to monitor val step metrics # on purpose DO NOT allow b_step... it's silly to monitor val step metrics
assert set(trainer.callback_metrics) == {'a', 'a2', 'b', 'a_epoch', 'b_epoch', 'a_step'} assert set(trainer.callback_metrics) == {'a', 'a2', 'b', 'a_epoch', 'b_epoch', 'a_step'}