prune data parallel (#7510)

This commit is contained in:
Jirka Borovec 2021-05-13 07:23:02 +02:00 committed by GitHub
parent 072ad52b6b
commit 946aee0c7b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 3 additions and 77 deletions

View File

@ -61,6 +61,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Prune deprecated classif. metrics from `pytorch_lightning.metrics.functional.classification` ([7499](https://github.com/PyTorchLightning/pytorch-lightning/pull/7499))
- Removed deprecated data parallel classes `LightningDataParallel` and `LightningDistributedDataParallel` from `pytorch_lightning.overrides.data_parallel` ([7510](https://github.com/PyTorchLightning/pytorch-lightning/pull/7510))
- Removed deprecated trainer attributes - `get_model` and `accelerator_backend` ([7502](https://github.com/PyTorchLightning/pytorch-lightning/pull/7502))

View File

@ -16,12 +16,9 @@ import warnings
from typing import Any
import torch
from torch.nn import DataParallel
from torch.nn.parallel import DistributedDataParallel
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
from pytorch_lightning.overrides.distributed import LightningDistributedModule
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.apply_func import apply_to_collection
@ -35,27 +32,6 @@ def _ignore_scalar_return_in_dp():
)
class LightningDataParallel(DataParallel):
def __init__(self, module: LightningModule, *args, **kwargs):
warnings.warn(
"The usage of `LightningDataParallel` is deprecated since v1.2 and will be removed in v1.4."
" From now on we recommend to directly subclass `torch.nn.parallel.DataParallel`.", DeprecationWarning
)
super().__init__(LightningParallelModule(module), *args, **kwargs)
class LightningDistributedDataParallel(DistributedDataParallel):
def __init__(self, module: LightningModule, *args, **kwargs):
warnings.warn(
"The usage of `LightningDistributedDataParallel` is deprecated since v1.2 and will be removed in v1.4."
" From now on we recommend to directly subclass `torch.nn.parallel.DistributedDataParallel`.",
DeprecationWarning
)
super().__init__(LightningDistributedModule(module), *args, **kwargs)
class LightningParallelModule(_LightningModuleWrapperBase):
"""
Wraps the user's LightningModule and redirects the forward call to the appropriate

View File

@ -14,20 +14,10 @@
"""Test deprecated functionality which will be removed in v1.4.0"""
import pytest
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.overrides.data_parallel import (
LightningDataParallel,
LightningDistributedDataParallel,
LightningParallelModule,
)
from pytorch_lightning.overrides.distributed import LightningDistributedModule
from pytorch_lightning.plugins import DDPSpawnPlugin
from pytorch_lightning.plugins.environments import LightningEnvironment
from tests.deprecated_api import _soft_unimport_module
from tests.helpers import BoringModel
from tests.helpers.runif import RunIf
def test_v1_4_0_deprecated_imports():
@ -48,49 +38,6 @@ def test_v1_4_0_deprecated_imports():
from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils # noqa: F811 F401
class CustomDDPPlugin(DDPSpawnPlugin):
def configure_ddp(self):
# old, deprecated implementation
with pytest.deprecated_call(
match='`LightningDistributedDataParallel` is deprecated since v1.2 and will be removed in v1.4.'
):
self._model = LightningDistributedDataParallel(
module=self.lightning_module,
device_ids=self.determine_ddp_device_ids(),
**self._ddp_kwargs,
)
assert isinstance(self.model, torch.nn.parallel.DistributedDataParallel)
assert isinstance(self.model.module, LightningDistributedModule)
@RunIf(min_gpus=2, skip_windows=True)
def test_v1_4_0_deprecated_lightning_distributed_data_parallel(tmpdir):
model = BoringModel()
trainer = Trainer(
default_root_dir=tmpdir,
fast_dev_run=True,
gpus=2,
accelerator="ddp_spawn",
plugins=[
CustomDDPPlugin(
parallel_devices=[torch.device("cuda", 0), torch.device("cuda", 1)],
cluster_environment=LightningEnvironment(),
)
]
)
trainer.fit(model)
@RunIf(min_gpus=1)
def test_v1_4_0_deprecated_lightning_data_parallel():
model = BoringModel()
with pytest.deprecated_call(match="`LightningDataParallel` is deprecated since v1.2 and will be removed in v1.4."):
dp_model = LightningDataParallel(model, device_ids=[0])
assert isinstance(dp_model, torch.nn.DataParallel)
assert isinstance(dp_model.module, LightningParallelModule)
def test_v1_4_0_deprecated_manual_optimization_optimizer(tmpdir):
class TestModel(BoringModel):