Mark forward_module as required (#16386)

Carlos Mocholí 2023-01-17 13:42:13 +01:00 committed by Luca Antiga
parent 46246c3336
commit 886ad49a55
14 changed files with 27 additions and 203 deletions

View File

@ -64,6 +64,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* env_prefix
* env_parse
- Mark the `forward_module` argument as required ([#16386](https://github.com/Lightning-AI/lightning/pull/16386))
* Removed the deprecated `pl_module` argument from the distributed module wrappers
* Removed the deprecated `pytorch_lightning.overrides.base.unwrap_lightning_module` function
* Removed the `pytorch_lightning.overrides.distributed.LightningDistributedModule` class
* Removed the deprecated `pytorch_lightning.overrides.fairscale.unwrap_lightning_module_sharded` function
* Removed the `pytorch_lightning.overrides.fairscale.LightningDistributedModule` class
- Removed the deprecated automatic GPU selection ([#16184](https://github.com/Lightning-AI/lightning/pull/16184))
* Removed the `Trainer(auto_select_gpus=...)` argument
* Removed the `pytorch_lightning.tuner.auto_gpu_select.{pick_single_gpu,pick_multiple_gpus}` functions
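In practice, the migration implied by the wrapper removals above is that the wrapped module must now be passed as `forward_module` (positionally or by keyword); the old `pl_module` keyword is gone. A minimal sketch, using `BoringModel` purely as a stand-in for any `LightningModule`:

```python
from pytorch_lightning.demos.boring_classes import BoringModel
from pytorch_lightning.overrides.data_parallel import LightningParallelModule

model = BoringModel()

# Removed in this release:
#   LightningParallelModule(pl_module=model)
# The module is now a required argument:
wrapped = LightningParallelModule(model)                 # positional
wrapped = LightningParallelModule(forward_module=model)  # or by keyword
```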

View File

@ -1,2 +0,0 @@
from pytorch_lightning.overrides.data_parallel import LightningParallelModule # noqa: F401
from pytorch_lightning.overrides.distributed import LightningDistributedModule # noqa: F401

View File

@ -11,16 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional, Union
from typing import Any, Union
import torch
import torch.nn as nn
from torch.nn import DataParallel
from torch.nn.parallel import DistributedDataParallel
import pytorch_lightning as pl
from lightning_fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation
class _LightningPrecisionModuleWrapperBase(_DeviceDtypeModuleMixin, torch.nn.Module):
@ -55,9 +51,7 @@ class _LightningPrecisionModuleWrapperBase(_DeviceDtypeModuleMixin, torch.nn.Mod
class _LightningModuleWrapperBase(_DeviceDtypeModuleMixin, torch.nn.Module):
def __init__(
self, forward_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]]
) -> None:
def __init__(self, forward_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]) -> None:
"""Wraps the user's LightningModule and redirects the forward call to the appropriate method, either
``training_step``, ``validation_step``, ``test_step``, or ``predict_step``.
@ -75,8 +69,6 @@ class _LightningModuleWrapperBase(_DeviceDtypeModuleMixin, torch.nn.Module):
"`forward_module` must be a `LightningModule` instance or have an attribute `.module` pointing to one,"
f" got: {forward_module.__class__.__qualname__}"
)
# TODO: In v2.0.0, remove the Optional type from forward_module and remove the assertion
assert forward_module is not None
self._forward_module = forward_module
# set the parameters_to_ignore from LightningModule.
@ -111,47 +103,3 @@ class _LightningModuleWrapperBase(_DeviceDtypeModuleMixin, torch.nn.Module):
if trainer.predicting:
return self._forward_module.predict_step(*inputs, **kwargs)
return self._forward_module(*inputs, **kwargs)
@classmethod
def _validate_init_arguments(
cls,
pl_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None,
forward_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None,
) -> None:
# TODO: In v2.0.0, remove this method and mark the forward_module init argument in all subclasses as required
if pl_module is not None:
rank_zero_deprecation(
f"The argument `pl_module` in `{cls.__name__}` is deprecated in v1.8.0 and will be removed in"
" v2.0.0. Please use `forward_module` instead."
)
elif forward_module is None:
raise ValueError("Argument `forward_module` is required.")
def unwrap_lightning_module(wrapped_model: nn.Module, _suppress_warning: bool = False) -> "pl.LightningModule":
"""Recursively unwraps a :class:`~pytorch_lightning.core.module.LightningModule` by following the ``.module``
attributes on the wrapper.
.. deprecated:: v1.8.0
The function ``unwrap_lightning_module`` is deprecated in v1.8.0 and will be removed in v2.0.0. Access the
``LightningModule`` directly through the strategy attribute ``Strategy.lightning_module``.
Raises:
TypeError: If the unwrapping leads to a module that is not a LightningModule and that cannot be unwrapped
further.
"""
if not _suppress_warning:
rank_zero_deprecation(
"The function `unwrap_lightning_module` is deprecated in v1.8.0 and will be removed in v2.0.0. Access the"
" `LightningModule` directly through the strategy attribute `Strategy.lightning_module`."
)
model = wrapped_model
if isinstance(model, (DistributedDataParallel, DataParallel)):
model = unwrap_lightning_module(model.module)
if isinstance(model, _LightningModuleWrapperBase):
model = model.lightning_module
if isinstance(model, _LightningPrecisionModuleWrapperBase):
model = model.module
if not isinstance(model, pl.LightningModule):
raise TypeError(f"Unwrapping the module did not yield a `LightningModule`, got {type(model)} instead.")
return model
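With `unwrap_lightning_module` removed, the replacement named in its (now deleted) deprecation message is the strategy attribute. A minimal sketch of that access pattern; `fast_dev_run` is only there to keep the run cheap:

```python
import pytorch_lightning as pl
from pytorch_lightning.demos.boring_classes import BoringModel

model = BoringModel()
trainer = pl.Trainer(accelerator="cpu", devices=1, fast_dev_run=True, logger=False)
trainer.fit(model)

# Previously: unwrap_lightning_module(trainer.model)
# Now: reach the LightningModule through the strategy.
assert trainer.strategy.lightning_module is model
```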

View File

@ -13,7 +13,7 @@
# limitations under the License.
import numbers
import warnings
from typing import Any, Optional, Union
from typing import Any, Union
import torch
from lightning_utilities.core.apply_func import apply_to_collection
@ -52,23 +52,15 @@ class LightningParallelModule(_LightningModuleWrapperBase):
)
Args:
pl_module: The module to wrap. See description for `forward_module`.
.. deprecated:: v1.8.0
The argument ``pl_module`` is deprecated in v1.8.0 and will be removed in v2.0.0. Please use
``forward_module`` instead.
forward_module: The module to wrap. If it's not a ``LightningModule``, it must have an attribute ``.module``
pointing to a ``LightningModule`` reference.
"""
def __init__(
self,
forward_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None,
pl_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None,
forward_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase],
) -> None:
self._validate_init_arguments(pl_module, forward_module)
super().__init__(forward_module=(pl_module or forward_module))
super().__init__(forward_module=forward_module)
_ignore_scalar_return_in_dp()
def forward(self, *inputs: Any, **kwargs: Any) -> Any:
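For reference, this wrapper is what the DP strategy hands to `torch.nn.DataParallel`. A construction-only sketch of that composition (no forward call is made here, since no trainer is attached):

```python
from torch.nn import DataParallel
from pytorch_lightning.demos.boring_classes import BoringModel
from pytorch_lightning.overrides.data_parallel import LightningParallelModule

# Roughly mirrors how the DP strategy composes the wrappers internally.
dp_model = DataParallel(LightningParallelModule(BoringModel()))
```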

View File

@ -12,26 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
from typing import Any, cast, Iterable, Iterator, List, Optional, Sized, Union
from typing import Any, cast, Iterable, Iterator, List, Sized, Union
import torch
from torch import Tensor
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import BatchSampler, DistributedSampler, Sampler
import pytorch_lightning as pl
from lightning_fabric.utilities.distributed import _DatasetSamplerWrapper
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase
class LightningDistributedModule(_LightningModuleWrapperBase):
def __init__(
self,
forward_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None,
pl_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None,
) -> None:
self._validate_init_arguments(pl_module, forward_module)
super().__init__(forward_module=(pl_module or forward_module))
def _find_tensors(
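Since `LightningDistributedModule` added no behaviour on top of its base class, code that constructed it directly can switch to `_LightningModuleWrapperBase`, exactly as the strategy changes further down in this commit do. A minimal sketch:

```python
from pytorch_lightning.demos.boring_classes import BoringModel
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase

model = BoringModel()

# Previously: LightningDistributedModule(pl_module=model)
wrapped = _LightningModuleWrapperBase(model)
assert wrapped.lightning_module is model  # the original module stays reachable
```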

View File

@ -11,21 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Union
from typing import List
import torch.nn as nn
from lightning_utilities.core.imports import package_available
from torch.optim import Optimizer
import pytorch_lightning as pl
from lightning_fabric.plugins import Precision
from lightning_fabric.utilities.imports import _IS_WINDOWS
from pytorch_lightning.overrides.base import (
_LightningModuleWrapperBase,
_LightningPrecisionModuleWrapperBase,
unwrap_lightning_module,
)
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation
_FAIRSCALE_AVAILABLE = not _IS_WINDOWS and package_available("fairscale")
@ -35,37 +27,6 @@ else:
OSS = object
class LightningShardedDataParallel(_LightningModuleWrapperBase):
def __init__(
self,
forward_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None,
pl_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None,
) -> None:
rank_zero_deprecation(
"PyTorch Lightning's sharded implementation using FairScale has been deprecated in v1.9.0 and will be"
" removed in v2.0.0. You can try using the `Trainer(strategy='fsdp_native')` instead."
" The difference is that native FSDP uses PyTorch's implementation and the current strategy uses"
" FairScale's implementation (which was upstreamed to PyTorch). After removal, `strategy='fsdp'` will use"
" the native version by default."
)
self._validate_init_arguments(pl_module, forward_module)
super().__init__(forward_module=(pl_module or forward_module))
def unwrap_lightning_module_sharded(wrapped_model: nn.Module) -> "pl.LightningModule":
from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel
rank_zero_deprecation(
"The function `unwrap_lightning_module_sharded` is deprecated in v1.8.0 and will be removed in v2.0.0."
" Access the `LightningModule` directly through the strategy attribute `Strategy.lightning_module`."
)
model = wrapped_model
if isinstance(model, ShardedDataParallel):
model = model.module
return unwrap_lightning_module(model, _suppress_warning=True)
def _reinit_optimizers_with_oss(optimizers: List[Optimizer], precision: Precision, num_nodes: int) -> List["OSS"]:
for x, optimizer in enumerate(optimizers):
if not isinstance(optimizer, OSS):
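The removed `LightningShardedDataParallel` deprecation message pointed users at PyTorch's native FSDP; following that suggestion, the Trainer-level replacement would look roughly like the sketch below (assumes a multi-GPU environment, since sharding is not useful on CPU):

```python
import pytorch_lightning as pl

# Replacement suggested by the removed deprecation message: native FSDP
# instead of FairScale's sharded implementation.
trainer = pl.Trainer(accelerator="gpu", devices=2, strategy="fsdp_native")
```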

View File

@ -64,11 +64,8 @@ log = logging.getLogger(__name__)
class LightningBaguaModule(_LightningModuleWrapperBase):
def __init__(
self,
forward_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None,
pl_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None,
forward_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase],
) -> None:
self._validate_init_arguments(pl_module, forward_module)
forward_module = pl_module or forward_module
super().__init__(forward_module=forward_module)
# Bagua use `bagua_module_name` to distinguish different modules
self._bagua_module_name = f"{forward_module.__class__.__name__}{id(forward_module)}"

View File

@ -43,8 +43,7 @@ from lightning_fabric.utilities.optimizer import _optimizers_to_device
from lightning_fabric.utilities.seed import reset_seed
from lightning_fabric.utilities.types import ReduceOp
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.overrides import LightningDistributedModule
from pytorch_lightning.overrides.base import _LightningPrecisionModuleWrapperBase
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase
from pytorch_lightning.overrides.distributed import prepare_for_backward
from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE
from pytorch_lightning.plugins.precision import PrecisionPlugin
@ -291,7 +290,7 @@ class DDPStrategy(ParallelStrategy):
log.detail(f"{self.__class__.__name__}: configuring DistributedDataParallel")
self.pre_configure_ddp()
assert isinstance(self.model, (pl.LightningModule, _LightningPrecisionModuleWrapperBase))
self.model = self._setup_model(LightningDistributedModule(self.model))
self.model = self._setup_model(_LightningModuleWrapperBase(self.model))
self._register_ddp_hooks()
def determine_ddp_device_ids(self) -> Optional[List[int]]:

View File

@ -36,8 +36,7 @@ from lightning_fabric.utilities.distributed import group as _group
from lightning_fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_11
from lightning_fabric.utilities.optimizer import _optimizers_to_device
from lightning_fabric.utilities.types import ReduceOp
from pytorch_lightning.overrides import LightningDistributedModule
from pytorch_lightning.overrides.base import _LightningPrecisionModuleWrapperBase
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase
from pytorch_lightning.overrides.distributed import prepare_for_backward
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.launchers.multiprocessing import _MultiProcessingLauncher
@ -211,7 +210,7 @@ class DDPSpawnStrategy(ParallelStrategy):
def configure_ddp(self) -> None:
self.pre_configure_ddp()
assert isinstance(self.model, (pl.LightningModule, _LightningPrecisionModuleWrapperBase))
self.model = self._setup_model(LightningDistributedModule(self.model))
self.model = self._setup_model(_LightningModuleWrapperBase(self.model))
self._register_ddp_hooks()
# set up optimizers after the wrapped module has been moved to the device

View File

@ -22,7 +22,7 @@ from torch.optim.optimizer import Optimizer
import pytorch_lightning as pl
from lightning_fabric.plugins import CheckpointIO, ClusterEnvironment
from lightning_fabric.utilities.distributed import group as _group
from pytorch_lightning.overrides import LightningDistributedModule
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
from pytorch_lightning.overrides.torch_distributed import broadcast_object_list
from pytorch_lightning.plugins.io.hpu_plugin import HPUCheckpointIO
from pytorch_lightning.plugins.io.wrapper import _WrappingCheckpointIO
@ -123,7 +123,7 @@ class HPUParallelStrategy(DDPStrategy):
if _TORCH_LESSER_EQUAL_1_10_2:
log.detail(f"{self.__class__.__name__}: configuring DistributedDataParallel")
self._pre_configure_ddp()
self.model = self._setup_model(LightningDistributedModule(self.model)) # type: ignore
self.model = self._setup_model(_LightningModuleWrapperBase(self.model)) # type: ignore
if self.root_device.type == "hpu" and self._static_graph:
self._model._set_static_graph() # type: ignore
self._register_ddp_hooks()

View File

@ -28,7 +28,7 @@ from lightning_fabric.plugins.environments import XLAEnvironment
from lightning_fabric.utilities.data import has_len
from lightning_fabric.utilities.optimizer import _optimizers_to_device
from lightning_fabric.utilities.types import _PATH, ReduceOp
from pytorch_lightning.overrides import LightningDistributedModule
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
from pytorch_lightning.plugins.io.wrapper import _WrappingCheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy
@ -128,7 +128,7 @@ class TPUSpawnStrategy(DDPSpawnStrategy):
TPUSpawnStrategy._validate_patched_dataloaders(model)
import torch_xla.distributed.xla_multiprocessing as xmp
self.wrapped_model = xmp.MpModelWrapper(LightningDistributedModule(model))
self.wrapped_model = xmp.MpModelWrapper(_LightningModuleWrapperBase(model))
return super().connect(model)
def _configure_launcher(self) -> None:

View File

@ -17,17 +17,12 @@ from unittest import mock
import numpy
import pytest
import torch
from lightning_utilities.test.warning import no_warning_call
from torch.utils.data import DataLoader
from pytorch_lightning.accelerators.cpu import CPUAccelerator
from pytorch_lightning.core.mixins.device_dtype_mixin import DeviceDtypeModuleMixin
from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset
from pytorch_lightning.overrides import LightningDistributedModule, LightningParallelModule
from pytorch_lightning.overrides.base import unwrap_lightning_module
from pytorch_lightning.overrides.fairscale import LightningShardedDataParallel, unwrap_lightning_module_sharded
from pytorch_lightning.demos.boring_classes import RandomDataset
from pytorch_lightning.plugins.environments import LightningEnvironment
from pytorch_lightning.strategies.bagua import LightningBaguaModule
from pytorch_lightning.strategies.utils import on_colab_kaggle
from pytorch_lightning.utilities.apply_func import (
apply_to_collection,
@ -62,51 +57,6 @@ from pytorch_lightning.utilities.distributed import (
from pytorch_lightning.utilities.optimizer import optimizer_to_device, optimizers_to_device
from pytorch_lightning.utilities.seed import pl_worker_init_function, reset_seed, seed_everything
from pytorch_lightning.utilities.xla_device import inner_f, pl_multi_process, XLADeviceUtils
from tests_pytorch.helpers.runif import RunIf
@pytest.mark.parametrize(
"wrapper_class",
[
LightningParallelModule,
LightningDistributedModule,
LightningBaguaModule,
],
)
def test_v1_10_deprecated_pl_module_init_parameter(wrapper_class):
with no_warning_call(
DeprecationWarning, match=rf"The argument `pl_module` in `{wrapper_class.__name__}` is deprecated in v1.8.0"
):
wrapper_class(BoringModel())
with pytest.deprecated_call(
match=rf"The argument `pl_module` in `{wrapper_class.__name__}` is deprecated in v1.8.0"
):
wrapper_class(pl_module=BoringModel())
@RunIf(fairscale=True)
def test_v1_10_deprecated_fairscale_pl_module_init_parameter():
with no_warning_call(
DeprecationWarning, match=r"The argument `pl_module` in `LightningShardedDataParallel` is deprecated in v1.8.0"
), pytest.deprecated_call(match="FairScale has been deprecated in v1.9.0"):
LightningShardedDataParallel(BoringModel())
with pytest.deprecated_call(
match=r"The argument `pl_module` in `LightningShardedDataParallel` is deprecated in v1.8.0"
):
LightningShardedDataParallel(pl_module=BoringModel())
def test_v1_10_deprecated_unwrap_lightning_module():
with pytest.deprecated_call(match=r"The function `unwrap_lightning_module` is deprecated in v1.8.0"):
unwrap_lightning_module(BoringModel())
@RunIf(fairscale=True)
def test_v1_10_deprecated_unwrap_lightning_module_sharded():
with pytest.deprecated_call(match=r"The function `unwrap_lightning_module_sharded` is deprecated in v1.8.0"):
unwrap_lightning_module_sharded(BoringModel())
def test_v1_10_deprecated_on_colab_kaggle_func():

View File

@ -13,14 +13,9 @@
# limitations under the License.
import pytest
import torch
from torch.nn import DataParallel
from pytorch_lightning.demos.boring_classes import BoringModel
from pytorch_lightning.overrides.base import (
_LightningModuleWrapperBase,
_LightningPrecisionModuleWrapperBase,
unwrap_lightning_module,
)
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase
@pytest.mark.parametrize("wrapper_class", [_LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase])
@ -30,13 +25,3 @@ def test_wrapper_device_dtype(wrapper_class):
wrapped_model.to(dtype=torch.float16)
assert model.dtype == torch.float16
def test_unwrap_lightning_module():
model = BoringModel()
wrapped_model = _LightningPrecisionModuleWrapperBase(model)
wrapped_model = _LightningModuleWrapperBase(wrapped_model)
wrapped_model = DataParallel(wrapped_model)
with pytest.deprecated_call(match="The function `unwrap_lightning_module` is deprecated in v1.8.0"):
assert unwrap_lightning_module(wrapped_model) == model

View File

@ -20,7 +20,7 @@ from torch.nn import DataParallel
from pytorch_lightning import LightningModule
from pytorch_lightning.demos.boring_classes import BoringModel
from pytorch_lightning.overrides import LightningDistributedModule
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
from pytorch_lightning.overrides.data_parallel import (
LightningParallelModule,
python_scalar_to_tensor,
@ -30,7 +30,7 @@ from pytorch_lightning.trainer.states import RunningStage
from tests_pytorch.helpers.runif import RunIf
@pytest.mark.parametrize("wrapper_class", [LightningParallelModule, LightningDistributedModule])
@pytest.mark.parametrize("wrapper_class", [LightningParallelModule, _LightningModuleWrapperBase])
@pytest.mark.parametrize(
"stage",
[