Fix typing in `pl.overrides.data_parallel` (#10796)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Adrian Wälchli 2021-11-29 10:58:23 +01:00 committed by GitHub
parent 724a92b065
commit 97e52619ea
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 39 additions and 28 deletions

View File

@ -78,8 +78,6 @@ module = [
"pytorch_lightning.loops.epoch.training_epoch_loop",
"pytorch_lightning.loops.fit_loop",
"pytorch_lightning.loops.utilities",
"pytorch_lightning.overrides.base",
"pytorch_lightning.overrides.data_parallel",
"pytorch_lightning.overrides.distributed",
"pytorch_lightning.overrides.fairscale",
"pytorch_lightning.plugins.environments.lightning_environment",

View File

@ -14,6 +14,7 @@
from typing import Any, Union
import torch
import torch.nn as nn
from torch.nn import DataParallel
from torch.nn.parallel import DistributedDataParallel
@ -101,10 +102,19 @@ class _LightningModuleWrapperBase(DeviceDtypeModuleMixin, torch.nn.Module):
pass
def unwrap_lightning_module(wrapped_model) -> "pl.LightningModule":
def unwrap_lightning_module(wrapped_model: nn.Module) -> "pl.LightningModule":
"""Recursively unwraps a :class:`~pytorch_lightning.core.lightning.LightningModule` by following the
``.module`` attributes on the wrapper.
Raises:
TypeError: If the unwrapping leads to a module that is not a LightningModule and that cannot be unwrapped
further.
"""
model = wrapped_model
if isinstance(model, (DistributedDataParallel, DataParallel)):
model = unwrap_lightning_module(model.module)
if isinstance(model, (_LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase)):
model = unwrap_lightning_module(model.module)
if not isinstance(model, pl.LightningModule):
raise TypeError(f"Unwrapping the module did not yield a `LightningModule`, got {type(model)} instead.")
return model

View File

@ -13,7 +13,7 @@
# limitations under the License.
import numbers
import warnings
from typing import Any
from typing import Any, Union
import torch
@ -23,7 +23,7 @@ from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.apply_func import apply_to_collection
def _ignore_scalar_return_in_dp():
def _ignore_scalar_return_in_dp() -> None:
# Users get confused by this warning so we silence it
warnings.filterwarnings(
"ignore",
@ -57,12 +57,12 @@ class LightningParallelModule(_LightningModuleWrapperBase):
super().__init__(pl_module)
_ignore_scalar_return_in_dp()
def forward(self, *inputs, **kwargs):
def forward(self, *inputs: Any, **kwargs: Any) -> Any:
self.update_replica_device_attributes(inputs)
# forward call will redirect to training_step, validation_step, etc.
output = super().forward(*inputs, **kwargs)
def output_transform(data: Any):
def output_transform(data: Any) -> Any:
data = python_scalar_to_tensor(data, self.module.device)
data = unsqueeze_scalar_tensor(data)
return data
@ -101,7 +101,7 @@ class LightningParallelModule(_LightningModuleWrapperBase):
)
def python_scalar_to_tensor(data: Any, device: torch.device = torch.device("cpu")) -> Any:
def python_scalar_to_tensor(data: Any, device: Union[str, torch.device] = torch.device("cpu")) -> Any:
"""Converts a Python scalar number to a torch tensor and places it on the given device."""
if isinstance(data, numbers.Number):
data = torch.tensor([data], device=device)

View File

@ -57,8 +57,8 @@ class ParallelPlugin(TrainingTypePlugin, ABC):
return self.root_device.type == "xla" and _XLA_AVAILABLE
@property
def lightning_module(self):
return unwrap_lightning_module(self._model)
def lightning_module(self) -> Optional["pl.LightningModule"]:
return unwrap_lightning_module(self._model) if self._model is not None else None
@property
def global_rank(self) -> int:

View File

@ -101,13 +101,13 @@ class DDPShardedPlugin(DDPPlugin):
return optimizer.state_dict()
@property
def lightning_module(self) -> "pl.LightningModule":
def lightning_module(self) -> Optional["pl.LightningModule"]:
if not _FAIRSCALE_AVAILABLE: # pragma: no cover
raise MisconfigurationException(
"`DDPShardedPlugin` requires `fairscale` to be installed."
" Install it by running `pip install fairscale`."
)
return unwrap_lightning_module_sharded(self._model)
return unwrap_lightning_module_sharded(self._model) if self._model is not None else None
def pre_backward(self, closure_loss: torch.Tensor) -> None:
pass

View File

@ -101,13 +101,13 @@ class DDPSpawnShardedPlugin(DDPSpawnPlugin):
return optimizer.state_dict()
@property
def lightning_module(self) -> "pl.LightningModule":
def lightning_module(self) -> Optional["pl.LightningModule"]:
if not _FAIRSCALE_AVAILABLE: # pragma: no cover
raise MisconfigurationException(
"`DDPSpawnShardedPlugin` requires `fairscale` to be installed."
" Install it by running `pip install fairscale`."
)
return unwrap_lightning_module_sharded(self._model)
return unwrap_lightning_module_sharded(self._model) if self._model is not None else None
def pre_backward(self, closure_loss: torch.Tensor) -> None:
pass

View File

@ -177,9 +177,9 @@ class TrainingTypePlugin(ABC):
self._model = new_model
@property
def lightning_module(self) -> "pl.LightningModule":
def lightning_module(self) -> Optional["pl.LightningModule"]:
"""Returns the pure LightningModule without potential wrappers."""
return unwrap_lightning_module(self._model)
return unwrap_lightning_module(self._model) if self._model is not None else None
@property
def results(self) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:

View File

@ -42,7 +42,8 @@ from tests.helpers.runif import RunIf
)
def test_lightning_wrapper_module_methods(wrapper_class, stage):
"""Test that the LightningWrapper redirects .forward() to the LightningModule methods."""
pl_module = MagicMock()
pl_module = Mock(spec=LightningModule)
pl_module.trainer = Mock()
wrapped_module = wrapper_class(pl_module)
batch = torch.rand(5)

View File

@ -1,5 +1,6 @@
import os
from unittest import mock
from unittest.mock import Mock
import pytest
import torch
@ -256,14 +257,14 @@ def test_configure_ddp(tmpdir):
def test_custom_kwargs_sharded(tmpdir, cls):
"""Tests to ensure that if custom kwargs are passed, they are set correctly."""
plugin = cls(reduce_fp16=True)
plugin.model = Mock(spec=LightningModule)
plugin.model.trainer = Mock()
class_name = "sharded" if isinstance(plugin, DDPShardedPlugin) else "sharded_spawn"
with mock.patch.object(plugin, "_model", autospec=True):
with mock.patch(
f"pytorch_lightning.plugins.training_type.{class_name}.ShardedDataParallel", autospec=True
) as mock_sharded:
plugin.configure_ddp()
with mock.patch(
f"pytorch_lightning.plugins.training_type.{class_name}.ShardedDataParallel", autospec=True
) as mock_sharded:
plugin.configure_ddp()
args, kwargs = mock_sharded.call_args
assert "reduce_fp16" in kwargs
assert kwargs["reduce_fp16"]
@ -277,12 +278,13 @@ def test_custom_kwargs_sharded_reduce_buffer_size(tmpdir, params, expected_buffe
"""Tests to ensure that ``reduce_buffer_size`` is correctly set based on user kwargs."""
plugin = DDPShardedPlugin(**params)
plugin.num_nodes = num_nodes
plugin.model = Mock(spec=LightningModule)
plugin.model.trainer = Mock()
with mock.patch.object(plugin, "_model", autospec=True):
with mock.patch(
"pytorch_lightning.plugins.training_type.sharded.ShardedDataParallel", autospec=True
) as mock_sharded:
plugin.configure_ddp()
with mock.patch(
"pytorch_lightning.plugins.training_type.sharded.ShardedDataParallel", autospec=True
) as mock_sharded:
plugin.configure_ddp()
args, kwargs = mock_sharded.call_args
assert "reduce_buffer_size" in kwargs