From 946aee0c7b367754d857c8bea33f89440b294469 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Thu, 13 May 2021 07:23:02 +0200 Subject: [PATCH] prune data parallel (#7510) --- CHANGELOG.md | 3 ++ pytorch_lightning/overrides/data_parallel.py | 24 --------- tests/deprecated_api/test_remove_1-4.py | 53 -------------------- 3 files changed, 3 insertions(+), 77 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7635b6ad1b..ede96bd1d1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -61,6 +61,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Prune deprecated classif. metrics from `pytorch_lightning.metrics.functional.classification` ([7499](https://github.com/PyTorchLightning/pytorch-lightning/pull/7499)) +- Removed deprecated data parallel classes `LightningDataParallel` and `LightningDistributedDataParallel` from `pytorch_lightning.overrides.data_parallel` ([7510](https://github.com/PyTorchLightning/pytorch-lightning/pull/7510)) + + - Removed deprecated trainer attributes - `get_model` and `accelerator_backend` ([7502](https://github.com/PyTorchLightning/pytorch-lightning/pull/7502)) diff --git a/pytorch_lightning/overrides/data_parallel.py b/pytorch_lightning/overrides/data_parallel.py index 272f4c6750..3d6e527ef9 100644 --- a/pytorch_lightning/overrides/data_parallel.py +++ b/pytorch_lightning/overrides/data_parallel.py @@ -16,12 +16,9 @@ import warnings from typing import Any import torch -from torch.nn import DataParallel -from torch.nn.parallel import DistributedDataParallel from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.overrides.base import _LightningModuleWrapperBase -from pytorch_lightning.overrides.distributed import LightningDistributedModule from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.apply_func import apply_to_collection @@ -35,27 +32,6 @@ def _ignore_scalar_return_in_dp(): ) -class LightningDataParallel(DataParallel): - - def __init__(self, module: LightningModule, *args, **kwargs): - warnings.warn( - "The usage of `LightningDataParallel` is deprecated since v1.2 and will be removed in v1.4." - " From now on we recommend to directly subclass `torch.nn.parallel.DataParallel`.", DeprecationWarning - ) - super().__init__(LightningParallelModule(module), *args, **kwargs) - - -class LightningDistributedDataParallel(DistributedDataParallel): - - def __init__(self, module: LightningModule, *args, **kwargs): - warnings.warn( - "The usage of `LightningDistributedDataParallel` is deprecated since v1.2 and will be removed in v1.4." - " From now on we recommend to directly subclass `torch.nn.parallel.DistributedDataParallel`.", - DeprecationWarning - ) - super().__init__(LightningDistributedModule(module), *args, **kwargs) - - class LightningParallelModule(_LightningModuleWrapperBase): """ Wraps the user's LightningModule and redirects the forward call to the appropriate diff --git a/tests/deprecated_api/test_remove_1-4.py b/tests/deprecated_api/test_remove_1-4.py index 851618f644..83220e99bb 100644 --- a/tests/deprecated_api/test_remove_1-4.py +++ b/tests/deprecated_api/test_remove_1-4.py @@ -14,20 +14,10 @@ """Test deprecated functionality which will be removed in v1.4.0""" import pytest -import torch from pytorch_lightning import Trainer -from pytorch_lightning.overrides.data_parallel import ( - LightningDataParallel, - LightningDistributedDataParallel, - LightningParallelModule, -) -from pytorch_lightning.overrides.distributed import LightningDistributedModule -from pytorch_lightning.plugins import DDPSpawnPlugin -from pytorch_lightning.plugins.environments import LightningEnvironment from tests.deprecated_api import _soft_unimport_module from tests.helpers import BoringModel -from tests.helpers.runif import RunIf def test_v1_4_0_deprecated_imports(): @@ -48,49 +38,6 @@ def test_v1_4_0_deprecated_imports(): from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils # noqa: F811 F401 -class CustomDDPPlugin(DDPSpawnPlugin): - - def configure_ddp(self): - # old, deprecated implementation - with pytest.deprecated_call( - match='`LightningDistributedDataParallel` is deprecated since v1.2 and will be removed in v1.4.' - ): - self._model = LightningDistributedDataParallel( - module=self.lightning_module, - device_ids=self.determine_ddp_device_ids(), - **self._ddp_kwargs, - ) - assert isinstance(self.model, torch.nn.parallel.DistributedDataParallel) - assert isinstance(self.model.module, LightningDistributedModule) - - -@RunIf(min_gpus=2, skip_windows=True) -def test_v1_4_0_deprecated_lightning_distributed_data_parallel(tmpdir): - model = BoringModel() - trainer = Trainer( - default_root_dir=tmpdir, - fast_dev_run=True, - gpus=2, - accelerator="ddp_spawn", - plugins=[ - CustomDDPPlugin( - parallel_devices=[torch.device("cuda", 0), torch.device("cuda", 1)], - cluster_environment=LightningEnvironment(), - ) - ] - ) - trainer.fit(model) - - -@RunIf(min_gpus=1) -def test_v1_4_0_deprecated_lightning_data_parallel(): - model = BoringModel() - with pytest.deprecated_call(match="`LightningDataParallel` is deprecated since v1.2 and will be removed in v1.4."): - dp_model = LightningDataParallel(model, device_ids=[0]) - assert isinstance(dp_model, torch.nn.DataParallel) - assert isinstance(dp_model.module, LightningParallelModule) - - def test_v1_4_0_deprecated_manual_optimization_optimizer(tmpdir): class TestModel(BoringModel):