Fix typing in `pl.overrides.data_parallel` (#10796)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

parent 724a92b065
commit 97e52619ea
@@ -78,8 +78,6 @@ module = [
     "pytorch_lightning.loops.epoch.training_epoch_loop",
     "pytorch_lightning.loops.fit_loop",
     "pytorch_lightning.loops.utilities",
-    "pytorch_lightning.overrides.base",
-    "pytorch_lightning.overrides.data_parallel",
     "pytorch_lightning.overrides.distributed",
     "pytorch_lightning.overrides.fairscale",
     "pytorch_lightning.plugins.environments.lightning_environment",
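Dropping `pytorch_lightning.overrides.base` and `pytorch_lightning.overrides.data_parallel` from this exclusion list means mypy now type-checks both modules instead of ignoring their errors. As a rough sketch (not part of the commit, assuming mypy and the project are installed and the paths are run from the repository root), the effect can be reproduced with mypy's Python API:

# Rough sketch: run mypy on the two modules that are no longer excluded.
from mypy import api

stdout, stderr, exit_status = api.run(
    [
        "pytorch_lightning/overrides/base.py",
        "pytorch_lightning/overrides/data_parallel.py",
    ]
)
print(stdout)
print("mypy exit status:", exit_status)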
@@ -14,6 +14,7 @@
 from typing import Any, Union
 
 import torch
+import torch.nn as nn
 from torch.nn import DataParallel
 from torch.nn.parallel import DistributedDataParallel
@@ -101,10 +102,19 @@ class _LightningModuleWrapperBase(DeviceDtypeModuleMixin, torch.nn.Module):
         pass
 
 
-def unwrap_lightning_module(wrapped_model) -> "pl.LightningModule":
+def unwrap_lightning_module(wrapped_model: nn.Module) -> "pl.LightningModule":
+    """Recursively unwraps a :class:`~pytorch_lightning.core.lightning.LightningModule` by following the
+    ``.module`` attributes on the wrapper.
+
+    Raises:
+        TypeError: If the unwrapping leads to a module that is not a LightningModule and that cannot be unwrapped
+            further.
+    """
     model = wrapped_model
     if isinstance(model, (DistributedDataParallel, DataParallel)):
         model = unwrap_lightning_module(model.module)
     if isinstance(model, (_LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase)):
         model = unwrap_lightning_module(model.module)
+    if not isinstance(model, pl.LightningModule):
+        raise TypeError(f"Unwrapping the module did not yield a `LightningModule`, got {type(model)} instead.")
     return model
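The stricter contract is easiest to see in isolation. A minimal sketch of how `unwrap_lightning_module` now behaves (`TinyModule` below is a throwaway toy module used purely for illustration):

# Illustrative sketch only; `TinyModule` is a made-up toy module, not part of the commit.
import torch.nn as nn
import pytorch_lightning as pl
from pytorch_lightning.overrides.base import unwrap_lightning_module


class TinyModule(pl.LightningModule):
    def __init__(self) -> None:
        super().__init__()
        self.layer = nn.Linear(2, 2)


model = TinyModule()

# Unwrapping follows the `.module` attribute through (Distributed)DataParallel wrappers.
assert unwrap_lightning_module(nn.DataParallel(model)) is model

# A wrapper chain that does not end in a LightningModule now raises TypeError
# instead of silently returning an arbitrary nn.Module.
try:
    unwrap_lightning_module(nn.DataParallel(nn.Linear(2, 2)))
except TypeError as err:
    print(err)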
@@ -13,7 +13,7 @@
 # limitations under the License.
 import numbers
 import warnings
-from typing import Any
+from typing import Any, Union
 
 import torch
 
@@ -23,7 +23,7 @@ from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.apply_func import apply_to_collection
 
 
-def _ignore_scalar_return_in_dp():
+def _ignore_scalar_return_in_dp() -> None:
     # Users get confused by this warning so we silence it
     warnings.filterwarnings(
         "ignore",
@@ -57,12 +57,12 @@ class LightningParallelModule(_LightningModuleWrapperBase):
         super().__init__(pl_module)
         _ignore_scalar_return_in_dp()
 
-    def forward(self, *inputs, **kwargs):
+    def forward(self, *inputs: Any, **kwargs: Any) -> Any:
         self.update_replica_device_attributes(inputs)
         # forward call will redirect to training_step, validation_step, etc.
         output = super().forward(*inputs, **kwargs)
 
-        def output_transform(data: Any):
+        def output_transform(data: Any) -> Any:
             data = python_scalar_to_tensor(data, self.module.device)
             data = unsqueeze_scalar_tensor(data)
             return data
@@ -101,7 +101,7 @@ class LightningParallelModule(_LightningModuleWrapperBase):
             )
 
 
-def python_scalar_to_tensor(data: Any, device: torch.device = torch.device("cpu")) -> Any:
+def python_scalar_to_tensor(data: Any, device: Union[str, torch.device] = torch.device("cpu")) -> Any:
     """Converts a Python scalar number to a torch tensor and places it on the given device."""
     if isinstance(data, numbers.Number):
         data = torch.tensor([data], device=device)
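With the widened annotation, both a device string and a `torch.device` are valid, matching what `torch.tensor(..., device=...)` accepts. A small usage sketch (not part of the commit):

# Small usage sketch of the widened device parameter.
import torch
from pytorch_lightning.overrides.data_parallel import python_scalar_to_tensor

print(python_scalar_to_tensor(3.14, torch.device("cpu")))  # tensor([3.1400])
print(python_scalar_to_tensor(3.14, "cpu"))                # a plain device string works as well
print(python_scalar_to_tensor("not a number"))             # non-scalar inputs pass through unchanged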
@@ -57,8 +57,8 @@ class ParallelPlugin(TrainingTypePlugin, ABC):
         return self.root_device.type == "xla" and _XLA_AVAILABLE
 
     @property
-    def lightning_module(self):
-        return unwrap_lightning_module(self._model)
+    def lightning_module(self) -> Optional["pl.LightningModule"]:
+        return unwrap_lightning_module(self._model) if self._model is not None else None
 
     @property
     def global_rank(self) -> int:
@@ -101,13 +101,13 @@ class DDPShardedPlugin(DDPPlugin):
         return optimizer.state_dict()
 
     @property
-    def lightning_module(self) -> "pl.LightningModule":
+    def lightning_module(self) -> Optional["pl.LightningModule"]:
         if not _FAIRSCALE_AVAILABLE:  # pragma: no cover
             raise MisconfigurationException(
                 "`DDPShardedPlugin` requires `fairscale` to be installed."
                 " Install it by running `pip install fairscale`."
             )
-        return unwrap_lightning_module_sharded(self._model)
+        return unwrap_lightning_module_sharded(self._model) if self._model is not None else None
 
     def pre_backward(self, closure_loss: torch.Tensor) -> None:
         pass
@@ -101,13 +101,13 @@ class DDPSpawnShardedPlugin(DDPSpawnPlugin):
         return optimizer.state_dict()
 
     @property
-    def lightning_module(self) -> "pl.LightningModule":
+    def lightning_module(self) -> Optional["pl.LightningModule"]:
         if not _FAIRSCALE_AVAILABLE:  # pragma: no cover
             raise MisconfigurationException(
                 "`DDPSpawnShardedPlugin` requires `fairscale` to be installed."
                 " Install it by running `pip install fairscale`."
             )
-        return unwrap_lightning_module_sharded(self._model)
+        return unwrap_lightning_module_sharded(self._model) if self._model is not None else None
 
     def pre_backward(self, closure_loss: torch.Tensor) -> None:
         pass
@@ -177,9 +177,9 @@ class TrainingTypePlugin(ABC):
         self._model = new_model
 
     @property
-    def lightning_module(self) -> "pl.LightningModule":
+    def lightning_module(self) -> Optional["pl.LightningModule"]:
         """Returns the pure LightningModule without potential wrappers."""
-        return unwrap_lightning_module(self._model)
+        return unwrap_lightning_module(self._model) if self._model is not None else None
 
     @property
     def results(self) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:
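The `Optional` return type reflects that `self._model` is `None` until a model has been attached to the plugin; guarding the unwrap call keeps the annotation honest for mypy and avoids passing `None` into `unwrap_lightning_module`. A standalone sketch of the same None-guarded property pattern (class and helper names below are illustrative, not the actual plugin API):

# Standalone sketch of the None-guarded property pattern; names are illustrative.
from typing import Optional

import torch.nn as nn


class PluginSketch:
    def __init__(self) -> None:
        self._model: Optional[nn.Module] = None  # attached later, may still be None

    @property
    def inner_module(self) -> Optional[nn.Module]:
        # Unwrap only when a model is present; otherwise surface None to the caller.
        return self._unwrap(self._model) if self._model is not None else None

    @staticmethod
    def _unwrap(model: nn.Module) -> nn.Module:
        # Follow `.module` attributes, mimicking how wrapped models are peeled back.
        while hasattr(model, "module") and isinstance(model.module, nn.Module):
            model = model.module
        return model


plugin = PluginSketch()
assert plugin.inner_module is None                  # nothing attached yet
plugin._model = nn.DataParallel(nn.Linear(2, 2))
assert isinstance(plugin.inner_module, nn.Linear)   # unwrapped back to the inner module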
@@ -42,7 +42,8 @@ from tests.helpers.runif import RunIf
 )
 def test_lightning_wrapper_module_methods(wrapper_class, stage):
     """Test that the LightningWrapper redirects .forward() to the LightningModule methods."""
-    pl_module = MagicMock()
+    pl_module = Mock(spec=LightningModule)
+    pl_module.trainer = Mock()
     wrapped_module = wrapper_class(pl_module)
 
     batch = torch.rand(5)
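The switch from `MagicMock()` to `Mock(spec=LightningModule)` matters because a spec'd mock passes `isinstance` checks against its spec class, which the stricter `unwrap_lightning_module` now performs. A minimal illustration (not part of the commit):

# Minimal illustration of why the spec'd mock is needed.
from unittest.mock import MagicMock, Mock

from pytorch_lightning import LightningModule

plain = MagicMock()
specced = Mock(spec=LightningModule)

print(isinstance(plain, LightningModule))    # False -> would trip the new TypeError when unwrapped
print(isinstance(specced, LightningModule))  # True  -> accepted as a LightningModule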
@@ -1,5 +1,6 @@
 import os
 from unittest import mock
+from unittest.mock import Mock
 
 import pytest
 import torch
@@ -256,14 +257,14 @@ def test_configure_ddp(tmpdir):
 def test_custom_kwargs_sharded(tmpdir, cls):
     """Tests to ensure that if custom kwargs are passed, they are set correctly."""
     plugin = cls(reduce_fp16=True)
-
+    plugin.model = Mock(spec=LightningModule)
+    plugin.model.trainer = Mock()
     class_name = "sharded" if isinstance(plugin, DDPShardedPlugin) else "sharded_spawn"
 
-    with mock.patch.object(plugin, "_model", autospec=True):
-        with mock.patch(
-            f"pytorch_lightning.plugins.training_type.{class_name}.ShardedDataParallel", autospec=True
-        ) as mock_sharded:
-            plugin.configure_ddp()
+    with mock.patch(
+        f"pytorch_lightning.plugins.training_type.{class_name}.ShardedDataParallel", autospec=True
+    ) as mock_sharded:
+        plugin.configure_ddp()
     args, kwargs = mock_sharded.call_args
     assert "reduce_fp16" in kwargs
     assert kwargs["reduce_fp16"]
@@ -277,12 +278,13 @@ def test_custom_kwargs_sharded_reduce_buffer_size(tmpdir, params, expected_buffe
     """Tests to ensure that ``reduce_buffer_size`` is correctly set based on user kwargs."""
     plugin = DDPShardedPlugin(**params)
     plugin.num_nodes = num_nodes
+    plugin.model = Mock(spec=LightningModule)
+    plugin.model.trainer = Mock()
 
-    with mock.patch.object(plugin, "_model", autospec=True):
-        with mock.patch(
-            "pytorch_lightning.plugins.training_type.sharded.ShardedDataParallel", autospec=True
-        ) as mock_sharded:
-            plugin.configure_ddp()
+    with mock.patch(
+        "pytorch_lightning.plugins.training_type.sharded.ShardedDataParallel", autospec=True
+    ) as mock_sharded:
+        plugin.configure_ddp()
     args, kwargs = mock_sharded.call_args
     assert "reduce_buffer_size" in kwargs
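Assigning `Mock(spec=LightningModule)` to `plugin.model` replaces the earlier `mock.patch.object(plugin, "_model", ...)` indirection, while `mock.patch(..., autospec=True)` still intercepts the wrapper class so its keyword arguments can be inspected via `call_args`. A self-contained sketch of that patch-and-inspect pattern (`ShardedWrapper` and `configure` below are made-up stand-ins, not Lightning APIs):

# Self-contained sketch of the patch-and-inspect pattern used by these tests.
from unittest import mock


class ShardedWrapper:
    def __init__(self, reduce_fp16: bool = False) -> None:
        self.reduce_fp16 = reduce_fp16


def configure(reduce_fp16: bool = False) -> None:
    # The code under test constructs the wrapper; the test only checks the kwargs.
    ShardedWrapper(reduce_fp16=reduce_fp16)


with mock.patch(f"{__name__}.ShardedWrapper", autospec=True) as mock_sharded:
    configure(reduce_fp16=True)

args, kwargs = mock_sharded.call_args
assert "reduce_fp16" in kwargs and kwargs["reduce_fp16"]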