From 97e52619ea753aeec0b37acedd7568182242f8e7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Mon, 29 Nov 2021 10:58:23 +0100
Subject: [PATCH] Fix typing in `pl.overrides.data_parallel` (#10796)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 pyproject.toml                                |  2 --
 pytorch_lightning/overrides/base.py           | 12 +++++++++-
 pytorch_lightning/overrides/data_parallel.py  | 10 ++++----
 .../plugins/training_type/parallel.py         |  4 ++--
 .../plugins/training_type/sharded.py          |  4 ++--
 .../plugins/training_type/sharded_spawn.py    |  4 ++--
 .../training_type/training_type_plugin.py     |  4 ++--
 tests/overrides/test_data_parallel.py         |  3 ++-
 tests/plugins/test_sharded_plugin.py          | 24 ++++++++++---------
 9 files changed, 39 insertions(+), 28 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index f219d8f509..0e56d3a3db 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -78,8 +78,6 @@ module = [
     "pytorch_lightning.loops.epoch.training_epoch_loop",
     "pytorch_lightning.loops.fit_loop",
     "pytorch_lightning.loops.utilities",
-    "pytorch_lightning.overrides.base",
-    "pytorch_lightning.overrides.data_parallel",
     "pytorch_lightning.overrides.distributed",
     "pytorch_lightning.overrides.fairscale",
     "pytorch_lightning.plugins.environments.lightning_environment",
diff --git a/pytorch_lightning/overrides/base.py b/pytorch_lightning/overrides/base.py
index fc22902495..d75628c1a8 100644
--- a/pytorch_lightning/overrides/base.py
+++ b/pytorch_lightning/overrides/base.py
@@ -14,6 +14,7 @@
 from typing import Any, Union
 
 import torch
+import torch.nn as nn
 from torch.nn import DataParallel
 from torch.nn.parallel import DistributedDataParallel
 
@@ -101,10 +102,19 @@ class _LightningModuleWrapperBase(DeviceDtypeModuleMixin, torch.nn.Module):
         pass
 
 
-def unwrap_lightning_module(wrapped_model) -> "pl.LightningModule":
+def unwrap_lightning_module(wrapped_model: nn.Module) -> "pl.LightningModule":
+    """Recursively unwraps a :class:`~pytorch_lightning.core.lightning.LightningModule` by following the
+    ``.module`` attributes on the wrapper.
+
+    Raises:
+        TypeError: If the unwrapping leads to a module that is not a LightningModule and that cannot be unwrapped
+            further.
+    """
     model = wrapped_model
     if isinstance(model, (DistributedDataParallel, DataParallel)):
         model = unwrap_lightning_module(model.module)
     if isinstance(model, (_LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase)):
         model = unwrap_lightning_module(model.module)
+    if not isinstance(model, pl.LightningModule):
+        raise TypeError(f"Unwrapping the module did not yield a `LightningModule`, got {type(model)} instead.")
     return model
diff --git a/pytorch_lightning/overrides/data_parallel.py b/pytorch_lightning/overrides/data_parallel.py
index 615f2c04e7..fd32619ed8 100644
--- a/pytorch_lightning/overrides/data_parallel.py
+++ b/pytorch_lightning/overrides/data_parallel.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 import numbers
 import warnings
-from typing import Any
+from typing import Any, Union
 
 import torch
 
@@ -23,7 +23,7 @@ from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.apply_func import apply_to_collection
 
 
-def _ignore_scalar_return_in_dp():
+def _ignore_scalar_return_in_dp() -> None:
     # Users get confused by this warning so we silence it
     warnings.filterwarnings(
         "ignore",
@@ -57,12 +57,12 @@ class LightningParallelModule(_LightningModuleWrapperBase):
         super().__init__(pl_module)
         _ignore_scalar_return_in_dp()
 
-    def forward(self, *inputs, **kwargs):
+    def forward(self, *inputs: Any, **kwargs: Any) -> Any:
         self.update_replica_device_attributes(inputs)
         # forward call will redirect to training_step, validation_step, etc.
         output = super().forward(*inputs, **kwargs)
 
-        def output_transform(data: Any):
+        def output_transform(data: Any) -> Any:
             data = python_scalar_to_tensor(data, self.module.device)
             data = unsqueeze_scalar_tensor(data)
             return data
@@ -101,7 +101,7 @@ class LightningParallelModule(_LightningModuleWrapperBase):
             )
 
 
-def python_scalar_to_tensor(data: Any, device: torch.device = torch.device("cpu")) -> Any:
+def python_scalar_to_tensor(data: Any, device: Union[str, torch.device] = torch.device("cpu")) -> Any:
     """Converts a Python scalar number to a torch tensor and places it on the given device."""
     if isinstance(data, numbers.Number):
         data = torch.tensor([data], device=device)
diff --git a/pytorch_lightning/plugins/training_type/parallel.py b/pytorch_lightning/plugins/training_type/parallel.py
index 07ede1ae4f..3a05455b87 100644
--- a/pytorch_lightning/plugins/training_type/parallel.py
+++ b/pytorch_lightning/plugins/training_type/parallel.py
@@ -57,8 +57,8 @@ class ParallelPlugin(TrainingTypePlugin, ABC):
         return self.root_device.type == "xla" and _XLA_AVAILABLE
 
     @property
-    def lightning_module(self):
-        return unwrap_lightning_module(self._model)
+    def lightning_module(self) -> Optional["pl.LightningModule"]:
+        return unwrap_lightning_module(self._model) if self._model is not None else None
 
     @property
     def global_rank(self) -> int:
diff --git a/pytorch_lightning/plugins/training_type/sharded.py b/pytorch_lightning/plugins/training_type/sharded.py
index e7f57e9c92..280d38bc83 100644
--- a/pytorch_lightning/plugins/training_type/sharded.py
+++ b/pytorch_lightning/plugins/training_type/sharded.py
@@ -101,13 +101,13 @@ class DDPShardedPlugin(DDPPlugin):
         return optimizer.state_dict()
 
     @property
-    def lightning_module(self) -> "pl.LightningModule":
+    def lightning_module(self) -> Optional["pl.LightningModule"]:
         if not _FAIRSCALE_AVAILABLE:  # pragma: no cover
             raise MisconfigurationException(
                 "`DDPShardedPlugin` requires `fairscale` to be installed."
                 " Install it by running `pip install fairscale`."
             )
-        return unwrap_lightning_module_sharded(self._model)
+        return unwrap_lightning_module_sharded(self._model) if self._model is not None else None
 
     def pre_backward(self, closure_loss: torch.Tensor) -> None:
         pass
diff --git a/pytorch_lightning/plugins/training_type/sharded_spawn.py b/pytorch_lightning/plugins/training_type/sharded_spawn.py
index 12c06b9dde..9f83f0261c 100644
--- a/pytorch_lightning/plugins/training_type/sharded_spawn.py
+++ b/pytorch_lightning/plugins/training_type/sharded_spawn.py
@@ -101,13 +101,13 @@ class DDPSpawnShardedPlugin(DDPSpawnPlugin):
         return optimizer.state_dict()
 
     @property
-    def lightning_module(self) -> "pl.LightningModule":
+    def lightning_module(self) -> Optional["pl.LightningModule"]:
         if not _FAIRSCALE_AVAILABLE:  # pragma: no cover
             raise MisconfigurationException(
                 "`DDPSpawnShardedPlugin` requires `fairscale` to be installed."
                 " Install it by running `pip install fairscale`."
             )
-        return unwrap_lightning_module_sharded(self._model)
+        return unwrap_lightning_module_sharded(self._model) if self._model is not None else None
 
     def pre_backward(self, closure_loss: torch.Tensor) -> None:
         pass
diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py
index be51cc9f92..b8244b9c2e 100644
--- a/pytorch_lightning/plugins/training_type/training_type_plugin.py
+++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -177,9 +177,9 @@ class TrainingTypePlugin(ABC):
         self._model = new_model
 
     @property
-    def lightning_module(self) -> "pl.LightningModule":
+    def lightning_module(self) -> Optional["pl.LightningModule"]:
         """Returns the pure LightningModule without potential wrappers."""
-        return unwrap_lightning_module(self._model)
+        return unwrap_lightning_module(self._model) if self._model is not None else None
 
     @property
     def results(self) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:
diff --git a/tests/overrides/test_data_parallel.py b/tests/overrides/test_data_parallel.py
index 62a497b310..1e0003486a 100644
--- a/tests/overrides/test_data_parallel.py
+++ b/tests/overrides/test_data_parallel.py
@@ -42,7 +42,8 @@ from tests.helpers.runif import RunIf
 )
 def test_lightning_wrapper_module_methods(wrapper_class, stage):
     """Test that the LightningWrapper redirects .forward() to the LightningModule methods."""
-    pl_module = MagicMock()
+    pl_module = Mock(spec=LightningModule)
+    pl_module.trainer = Mock()
     wrapped_module = wrapper_class(pl_module)
 
     batch = torch.rand(5)
diff --git a/tests/plugins/test_sharded_plugin.py b/tests/plugins/test_sharded_plugin.py
index 8a55633fb1..f6b58692aa 100644
--- a/tests/plugins/test_sharded_plugin.py
+++ b/tests/plugins/test_sharded_plugin.py
@@ -1,5 +1,6 @@
 import os
 from unittest import mock
+from unittest.mock import Mock
 
 import pytest
 import torch
@@ -256,14 +257,14 @@ def test_configure_ddp(tmpdir):
 def test_custom_kwargs_sharded(tmpdir, cls):
     """Tests to ensure that if custom kwargs are passed, they are set correctly."""
     plugin = cls(reduce_fp16=True)
-
+    plugin.model = Mock(spec=LightningModule)
+    plugin.model.trainer = Mock()
     class_name = "sharded" if isinstance(plugin, DDPShardedPlugin) else "sharded_spawn"
 
-    with mock.patch.object(plugin, "_model", autospec=True):
-        with mock.patch(
-            f"pytorch_lightning.plugins.training_type.{class_name}.ShardedDataParallel", autospec=True
-        ) as mock_sharded:
-            plugin.configure_ddp()
+    with mock.patch(
+        f"pytorch_lightning.plugins.training_type.{class_name}.ShardedDataParallel", autospec=True
+    ) as mock_sharded:
+        plugin.configure_ddp()
     args, kwargs = mock_sharded.call_args
     assert "reduce_fp16" in kwargs
     assert kwargs["reduce_fp16"]
@@ -277,12 +278,13 @@ def test_custom_kwargs_sharded_reduce_buffer_size(tmpdir, params, expected_buffe
     """Tests to ensure that ``reduce_buffer_size`` is correctly set based on user kwargs."""
     plugin = DDPShardedPlugin(**params)
     plugin.num_nodes = num_nodes
+    plugin.model = Mock(spec=LightningModule)
+    plugin.model.trainer = Mock()
 
-    with mock.patch.object(plugin, "_model", autospec=True):
-        with mock.patch(
-            "pytorch_lightning.plugins.training_type.sharded.ShardedDataParallel", autospec=True
-        ) as mock_sharded:
-            plugin.configure_ddp()
+    with mock.patch(
+        "pytorch_lightning.plugins.training_type.sharded.ShardedDataParallel", autospec=True
+    ) as mock_sharded:
+        plugin.configure_ddp()
     args, kwargs = mock_sharded.call_args
     assert "reduce_buffer_size" in kwargs
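
Note on the change (a sketch, not part of the patch): the behavioral core of this diff is that `unwrap_lightning_module` now accepts any `nn.Module`, follows `.module` attributes through the known wrapper types, and raises a `TypeError` if the innermost module is not a `LightningModule`, while the `lightning_module` properties return `None` when no model has been attached yet. The snippet below is a minimal, self-contained illustration of that unwrap-and-check pattern; `ToyLightningModule`, `ToyWrapper`, and `unwrap` are hypothetical stand-ins and not Lightning APIs.

# Hypothetical sketch of the unwrap-and-check pattern; the class names below are
# stand-ins, not Lightning classes.
import torch.nn as nn


class ToyLightningModule(nn.Module):
    """Stand-in for pl.LightningModule."""


class ToyWrapper(nn.Module):
    """Stand-in for DataParallel / _LightningModuleWrapperBase-style wrappers."""

    def __init__(self, module: nn.Module) -> None:
        super().__init__()
        self.module = module


def unwrap(wrapped_model: nn.Module) -> ToyLightningModule:
    # Recursively peel off `.module` wrappers, mirroring the patched function.
    model = wrapped_model
    if isinstance(model, ToyWrapper):
        model = unwrap(model.module)
    # Fail loudly instead of returning a non-LightningModule, as the patched function now does.
    if not isinstance(model, ToyLightningModule):
        raise TypeError(f"Unwrapping the module did not yield a `ToyLightningModule`, got {type(model)} instead.")
    return model


if __name__ == "__main__":
    inner = ToyLightningModule()
    assert unwrap(ToyWrapper(ToyWrapper(inner))) is inner  # nested wrappers are peeled off
    try:
        unwrap(nn.Linear(2, 2))  # a plain module cannot be unwrapped to a LightningModule
    except TypeError as err:
        print(f"raised as expected: {err}")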