Use typing forward references (#7770)
* Use typing forward references
* Update pytorch_lightning/core/lightning.py
This commit is contained in:
parent a69beab499
commit d47173bb72
@@ -324,42 +324,38 @@ class LightningModule(
                 ' `https://github.com/PyTorchLightning/pytorch-lightning/discussions`'
             )
 
-        if self._results is not None:
-            # TODO: if logged twice fail with crash
-
-            # set the default depending on the fx_name
-            on_step = self.__auto_choose_log_on_step(on_step)
-            on_epoch = self.__auto_choose_log_on_epoch(on_epoch)
-
-            assert self._current_fx_name is not None
-            self.trainer.logger_connector.check_logging(self._current_fx_name, on_step=on_step, on_epoch=on_epoch)
-
-            # make sure user doesn't introduce logic for multi-dataloaders
-            if "/dataloader_idx_" in name:
-                raise MisconfigurationException(
-                    f"Logged key: {name} should not contain information about dataloader_idx."
-                )
-
-            value = self.__sync(
-                value,
-                sync_fn=self.trainer.training_type_plugin.reduce,
-                sync_dist=sync_dist,
-                sync_dist_op=sync_dist_op,
-                sync_dist_group=sync_dist_group,
-                device=self.device,
-            )
-
-            self._results.log(
-                name,
-                value,
-                prog_bar=prog_bar,
-                logger=logger,
-                on_step=on_step,
-                on_epoch=on_epoch,
-                reduce_fx=reduce_fx,
-                enable_graph=enable_graph,
-                dataloader_idx=(self._current_dataloader_idx if add_dataloader_idx else None),
-            )
+        # set the default depending on the fx_name
+        on_step = self.__auto_choose_log_on_step(on_step)
+        on_epoch = self.__auto_choose_log_on_epoch(on_epoch)
+
+        assert self._current_fx_name is not None
+        self.trainer.logger_connector.check_logging(self._current_fx_name, on_step=on_step, on_epoch=on_epoch)
+
+        # make sure user doesn't introduce logic for multi-dataloaders
+        if "/dataloader_idx_" in name:
+            raise MisconfigurationException(f"Logged key: {name} should not contain information about dataloader_idx.")
+
+        value = self.__sync(
+            value,
+            sync_fn=self.trainer.training_type_plugin.reduce,
+            sync_dist=sync_dist,
+            sync_dist_op=sync_dist_op,
+            sync_dist_group=sync_dist_group,
+            device=self.device,
+        )
+
+        assert self._results is not None
+        self._results.log(
+            name,
+            value,
+            prog_bar=prog_bar,
+            logger=logger,
+            on_step=on_step,
+            on_epoch=on_epoch,
+            reduce_fx=reduce_fx,
+            enable_graph=enable_graph,
+            dataloader_idx=(self._current_dataloader_idx if add_dataloader_idx else None),
+        )
 
     def log_dict(
         self,
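The dataloader-index guard above is the user-facing part of this hunk. A hedged sketch of the rule it enforces follows; the helper name `_check_logged_key` is invented for illustration, only the check and the exception type come from the diff.

# Illustrative sketch only -- `_check_logged_key` is a hypothetical helper, not
# pytorch-lightning API. It mirrors the guard in the hunk above: `self.log`
# appends `/dataloader_idx_<i>` itself when `add_dataloader_idx=True`, so the
# user-supplied key must not already contain that suffix.
from pytorch_lightning.utilities.exceptions import MisconfigurationException


def _check_logged_key(name: str) -> None:
    if "/dataloader_idx_" in name:
        raise MisconfigurationException(f"Logged key: {name} should not contain information about dataloader_idx.")


_check_logged_key("val_loss")                      # fine: Lightning adds the suffix later
# _check_logged_key("val_loss/dataloader_idx_0")   # would raise MisconfigurationException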
@@ -378,7 +374,7 @@ class LightningModule(
         add_dataloader_idx: bool = True,
     ) -> None:
         """
-        Log a dictonary of values at once
+        Log a dictionary of values at once
 
         Example::
 
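For context, a hedged usage sketch of the `log_dict` call whose docstring is touched above. The module, metric names, and values are invented; only the method name and the `prog_bar`/`add_dataloader_idx` flags appear in the diff.

# Hedged sketch, not pytorch-lightning source: a minimal module showing the
# `log_dict` call shape -- one call logs every key/value pair in the dict.
import torch
import pytorch_lightning as pl


class SketchModule(pl.LightningModule):
    """Hypothetical module used only to illustrate `log_dict`."""

    def __init__(self) -> None:
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        loss = self.layer(batch).mean()
        # same flags as `self.log`, applied to every key at once
        self.log_dict({"train_loss": loss, "train_loss_x2": loss * 2}, prog_bar=True, add_dataloader_idx=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)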
@@ -15,13 +15,13 @@ import torch
 from torch.nn import DataParallel
 from torch.nn.parallel import DistributedDataParallel
 
-from pytorch_lightning.core.lightning import LightningModule
+import pytorch_lightning as pl
 from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
 
 
 class _LightningModuleWrapperBase(DeviceDtypeModuleMixin, torch.nn.Module):
 
-    def __init__(self, pl_module: LightningModule):
+    def __init__(self, pl_module: 'pl.LightningModule') -> None:
         """
         Wraps the user's LightningModule and redirects the forward call to the appropriate
         method, either ``training_step``, ``validation_step`` or ``test_step``.
@@ -66,7 +66,7 @@ class _LightningModuleWrapperBase(DeviceDtypeModuleMixin, torch.nn.Module):
         pass
 
 
-def unwrap_lightning_module(wrapped_model) -> LightningModule:
+def unwrap_lightning_module(wrapped_model) -> 'pl.LightningModule':
     model = wrapped_model
     if isinstance(model, (DistributedDataParallel, DataParallel)):
         model = model.module
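The two hunks above replace the concrete `LightningModule` import with the aliased package import plus a quoted annotation. A minimal sketch of that pattern written outside the library follows; the `WrapperSketch` and `unwrap_sketch` names are hypothetical.

# Hedged sketch of the forward-reference pattern adopted above, not
# pytorch-lightning source. The quoted annotation 'pl.LightningModule' is only
# evaluated lazily (by type checkers or typing.get_type_hints), so inside the
# library the module defining the wrapper never needs to import the class at
# load time, which is what avoids the circular import.
import torch
from torch.nn import DataParallel
from torch.nn.parallel import DistributedDataParallel

import pytorch_lightning as pl


class WrapperSketch(torch.nn.Module):
    """Hypothetical stand-in for _LightningModuleWrapperBase."""

    def __init__(self, pl_module: 'pl.LightningModule') -> None:
        super().__init__()
        self.module = pl_module


def unwrap_sketch(wrapped_model: torch.nn.Module) -> 'pl.LightningModule':
    # same peeling logic as unwrap_lightning_module in the hunk above
    model = wrapped_model
    if isinstance(model, (DistributedDataParallel, DataParallel)):
        model = model.module
    if isinstance(model, WrapperSketch):
        model = model.module
    return model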
@@ -17,7 +17,7 @@ from typing import Any
 
 import torch
 
-from pytorch_lightning.core.lightning import LightningModule
+import pytorch_lightning as pl
 from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
 from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.apply_func import apply_to_collection
@@ -53,7 +53,7 @@ class LightningParallelModule(_LightningModuleWrapperBase):
 
     """
 
-    def __init__(self, pl_module: LightningModule):
+    def __init__(self, pl_module: 'pl.LightningModule') -> None:
         super().__init__(pl_module)
         _ignore_scalar_return_in_dp()
 
@@ -18,13 +18,13 @@ import torch
 from torch.nn.parallel import DistributedDataParallel
 from torch.utils.data import BatchSampler, DistributedSampler, Sampler
 
-from pytorch_lightning.core.lightning import LightningModule
+import pytorch_lightning as pl
 from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
 
 
 class LightningDistributedModule(_LightningModuleWrapperBase):
 
-    def __init__(self, pl_module: LightningModule):
+    def __init__(self, pl_module: 'pl.LightningModule') -> None:
         """
         Wraps the user's LightningModule and redirects the forward call to the appropriate
         method, either ``training_step``, ``validation_step``, ``test_step`` or ``predict``.
@@ -39,7 +39,7 @@ class ApexMixedPrecisionPlugin(MixedPrecisionPlugin):
     def master_params(self, optimizer: Optimizer) -> _PARAMETERS:
         return amp.master_params(optimizer)
 
-    def dispatch(self, trainer: "pl.Trainer") -> None:
+    def dispatch(self, trainer: 'pl.Trainer') -> None:
         if not self._connected:
             accelerator = trainer.accelerator
             _, accelerator.optimizers = amp.initialize(
@@ -22,8 +22,8 @@ from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
 
 import torch
 
+import pytorch_lightning as pl
 from pytorch_lightning.callbacks import GradientAccumulationScheduler
-from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
 from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
@@ -51,7 +51,7 @@ def remove_module_hooks(model: torch.nn.Module) -> None:
 
 class LightningDeepSpeedModule(_LightningModuleWrapperBase):
 
-    def __init__(self, pl_module: LightningModule, precision: int):
+    def __init__(self, pl_module: 'pl.LightningModule', precision: int) -> None:
         super().__init__(pl_module)
         self.precision = precision
 
@@ -378,7 +378,7 @@ class DeepSpeedPlugin(DDPPlugin):
         distributed_sampler_kwargs = dict(num_replicas=self.world_size, rank=self.global_rank)
         return distributed_sampler_kwargs
 
-    def init_optimizers(self, trainer, model: LightningModule) -> Tuple[List, List, List]:
+    def init_optimizers(self, trainer: 'pl.Trainer', model: 'pl.LightningModule') -> Tuple[List, List, List]:
         # Skip initializing optimizers here as DeepSpeed handles optimizers via config.
         # User may have specified config options instead in configure_optimizers, but this is handled
         # via `_initialize_deepspeed_train`
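The comment in the last hunk explains why `init_optimizers` does no real work here: with DeepSpeed, the optimizer is typically declared in the DeepSpeed config rather than in `configure_optimizers`. A hedged illustration of such a config fragment follows; the values are invented and the schema is DeepSpeed's, not pytorch-lightning's.

# Illustrative DeepSpeed config fragment (invented values). When the engine is
# built from a config like this, the optimizer comes from the "optimizer"
# section, which is why the plugin skips Lightning's usual optimizer setup.
deepspeed_config = {
    "train_micro_batch_size_per_gpu": 8,
    "optimizer": {
        "type": "Adam",
        "params": {"lr": 3e-4},
    },
    "zero_optimization": {"stage": 2},
}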
@@ -19,7 +19,7 @@ from typing import Any, List, Optional
 import torch
 from torch.nn.parallel import DistributedDataParallel
 
-from pytorch_lightning.core.lightning import LightningModule
+import pytorch_lightning as pl
 from pytorch_lightning.overrides.base import unwrap_lightning_module
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
 from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin
@@ -99,7 +99,7 @@ class ParallelPlugin(TrainingTypePlugin, ABC):
         return torch_backend
 
     @staticmethod
-    def configure_sync_batchnorm(model: LightningModule) -> LightningModule:
+    def configure_sync_batchnorm(model: 'pl.LightningModule') -> 'pl.LightningModule':
         """
         Add global batchnorm for a model spread across multiple GPUs and nodes.
 
@@ -112,8 +112,7 @@ class ParallelPlugin(TrainingTypePlugin, ABC):
         Return:
             LightningModule with batchnorm layers synchronized between process groups
         """
-        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
-        return model
+        return torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
 
     @contextmanager
     def block_backward_sync(self):
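The last hunk above only collapses `configure_sync_batchnorm` into a single return; the underlying call is plain PyTorch. A hedged sketch of what that conversion does on an arbitrary model follows; the `Net` module is invented.

# Minimal sketch of torch.nn.SyncBatchNorm.convert_sync_batchnorm, the call
# used by configure_sync_batchnorm above. `Net` is a made-up module; the
# conversion swaps every BatchNorm*d layer for a SyncBatchNorm so statistics
# are reduced across the process group during distributed training.
import torch


class Net(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.body = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.BatchNorm2d(8), torch.nn.ReLU())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.body(x)


model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(Net())
print(type(model.body[1]))  # torch.nn.SyncBatchNorm after conversion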
@@ -125,5 +125,5 @@ def test__validation_step__log(tmpdir):
     }
 
     # we don't want to enable val metrics during steps because it is not something that users should do
-    # on purpose DO NOT allow step_b... it's silly to monitor val step metrics
+    # on purpose DO NOT allow b_step... it's silly to monitor val step metrics
     assert set(trainer.callback_metrics) == {'a', 'a2', 'b', 'a_epoch', 'b_epoch', 'a_step'}