prune data parallel (#7510)
parent 072ad52b6b
commit 946aee0c7b
@@ -61,6 +61,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Pruned deprecated classification metrics from `pytorch_lightning.metrics.functional.classification` ([#7499](https://github.com/PyTorchLightning/pytorch-lightning/pull/7499))
 
+- Removed deprecated data parallel classes `LightningDataParallel` and `LightningDistributedDataParallel` from `pytorch_lightning.overrides.data_parallel` ([#7510](https://github.com/PyTorchLightning/pytorch-lightning/pull/7510))
 
 - Removed deprecated trainer attributes - `get_model` and `accelerator_backend` ([#7502](https://github.com/PyTorchLightning/pytorch-lightning/pull/7502))
@@ -16,12 +16,9 @@ import warnings
 from typing import Any
 
 import torch
-from torch.nn import DataParallel
-from torch.nn.parallel import DistributedDataParallel
 
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
-from pytorch_lightning.overrides.distributed import LightningDistributedModule
 from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.apply_func import apply_to_collection
@@ -35,27 +32,6 @@ def _ignore_scalar_return_in_dp():
     )
 
 
-class LightningDataParallel(DataParallel):
-
-    def __init__(self, module: LightningModule, *args, **kwargs):
-        warnings.warn(
-            "The usage of `LightningDataParallel` is deprecated since v1.2 and will be removed in v1.4."
-            " From now on we recommend to directly subclass `torch.nn.parallel.DataParallel`.", DeprecationWarning
-        )
-        super().__init__(LightningParallelModule(module), *args, **kwargs)
-
-
-class LightningDistributedDataParallel(DistributedDataParallel):
-
-    def __init__(self, module: LightningModule, *args, **kwargs):
-        warnings.warn(
-            "The usage of `LightningDistributedDataParallel` is deprecated since v1.2 and will be removed in v1.4."
-            " From now on we recommend to directly subclass `torch.nn.parallel.DistributedDataParallel`.",
-            DeprecationWarning
-        )
-        super().__init__(LightningDistributedModule(module), *args, **kwargs)
-
-
 class LightningParallelModule(_LightningModuleWrapperBase):
     """
     Wraps the user's LightningModule and redirects the forward call to the appropriate
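The deprecation messages in the removed classes already name the migration path: wrap the `LightningModule` in `LightningParallelModule` (which this commit keeps) and hand it to the stock `torch.nn.DataParallel`. A minimal sketch of that replacement; `BoringModel` is just the repo's test helper standing in for any user `LightningModule`, and actual multi-GPU execution assumes CUDA devices are available:

```python
import torch
from torch.nn import DataParallel

from pytorch_lightning.overrides.data_parallel import LightningParallelModule
from tests.helpers import BoringModel  # stand-in for any LightningModule

model = BoringModel()

# Previously: dp_model = LightningDataParallel(model, device_ids=[0])
# Now: wrap explicitly, then use torch.nn.DataParallel directly.
dp_model = DataParallel(LightningParallelModule(model), device_ids=[0])

# Same invariants the removed v1.4 deprecation test asserted:
assert isinstance(dp_model, torch.nn.DataParallel)
assert isinstance(dp_model.module, LightningParallelModule)
```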
@@ -14,20 +14,10 @@
 """Test deprecated functionality which will be removed in v1.4.0"""
 
 import pytest
-import torch
-
 from pytorch_lightning import Trainer
-from pytorch_lightning.overrides.data_parallel import (
-    LightningDataParallel,
-    LightningDistributedDataParallel,
-    LightningParallelModule,
-)
-from pytorch_lightning.overrides.distributed import LightningDistributedModule
-from pytorch_lightning.plugins import DDPSpawnPlugin
-from pytorch_lightning.plugins.environments import LightningEnvironment
 from tests.deprecated_api import _soft_unimport_module
 from tests.helpers import BoringModel
 from tests.helpers.runif import RunIf
 
 
 def test_v1_4_0_deprecated_imports():
@@ -48,49 +38,6 @@ def test_v1_4_0_deprecated_imports():
         from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils  # noqa: F811 F401
 
 
-class CustomDDPPlugin(DDPSpawnPlugin):
-
-    def configure_ddp(self):
-        # old, deprecated implementation
-        with pytest.deprecated_call(
-            match='`LightningDistributedDataParallel` is deprecated since v1.2 and will be removed in v1.4.'
-        ):
-            self._model = LightningDistributedDataParallel(
-                module=self.lightning_module,
-                device_ids=self.determine_ddp_device_ids(),
-                **self._ddp_kwargs,
-            )
-            assert isinstance(self.model, torch.nn.parallel.DistributedDataParallel)
-            assert isinstance(self.model.module, LightningDistributedModule)
-
-
-@RunIf(min_gpus=2, skip_windows=True)
-def test_v1_4_0_deprecated_lightning_distributed_data_parallel(tmpdir):
-    model = BoringModel()
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        fast_dev_run=True,
-        gpus=2,
-        accelerator="ddp_spawn",
-        plugins=[
-            CustomDDPPlugin(
-                parallel_devices=[torch.device("cuda", 0), torch.device("cuda", 1)],
-                cluster_environment=LightningEnvironment(),
-            )
-        ]
-    )
-    trainer.fit(model)
-
-
-@RunIf(min_gpus=1)
-def test_v1_4_0_deprecated_lightning_data_parallel():
-    model = BoringModel()
-    with pytest.deprecated_call(match="`LightningDataParallel` is deprecated since v1.2 and will be removed in v1.4."):
-        dp_model = LightningDataParallel(model, device_ids=[0])
-    assert isinstance(dp_model, torch.nn.DataParallel)
-    assert isinstance(dp_model.module, LightningParallelModule)
-
-
 def test_v1_4_0_deprecated_manual_optimization_optimizer(tmpdir):
 
     class TestModel(BoringModel):
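The removed `CustomDDPPlugin` pinned down what a new-style `configure_ddp` should produce: a stock `DistributedDataParallel` whose `.module` is a `LightningDistributedModule`. A minimal sketch of such an override, assuming the same `DDPSpawnPlugin` surface (`lightning_module`, `determine_ddp_device_ids()`, `_ddp_kwargs`) that the removed test relied on; the class name is hypothetical:

```python
from torch.nn.parallel import DistributedDataParallel

from pytorch_lightning.overrides.distributed import LightningDistributedModule
from pytorch_lightning.plugins import DDPSpawnPlugin


class NewStyleDDPPlugin(DDPSpawnPlugin):
    """Sketch: build the torch wrapper directly, without LightningDistributedDataParallel."""

    def configure_ddp(self) -> None:
        # LightningDistributedModule redirects forward() to the active
        # training/validation/test step; DistributedDataParallel is used as-is.
        self._model = DistributedDataParallel(
            LightningDistributedModule(self.lightning_module),
            device_ids=self.determine_ddp_device_ids(),
            **self._ddp_kwargs,
        )
```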