From a451997c4da89be3b1e4f7f79b52015bd32f2ea4 Mon Sep 17 00:00:00 2001
From: four4fish <88516121+four4fish@users.noreply.github.com>
Date: Wed, 1 Sep 2021 19:23:59 -0700
Subject: [PATCH] Avoid wrapping LightningModule in DDP plugins when not fitting (#9096)

* Avoid wrapping LightningModule in DDP plugins when not fitting

* Avoid wrapping LightningModule in DDP plugins when not fitting
---
 CHANGELOG.md                               |  3 ++
 .../plugins/training_type/ddp.py           | 27 ++++++++----
 .../plugins/training_type/ddp_spawn.py     | 27 ++++++++----
 .../plugins/training_type/deepspeed.py     |  9 ++++
 .../plugins/training_type/sharded.py       |  2 +-
 .../plugins/training_type/sharded_spawn.py |  2 +-
 tests/plugins/test_ddp_plugin.py           | 37 ++++++++++++++++-
 tests/plugins/test_ddp_spawn_plugin.py     | 38 ++++++++++++++++-
 tests/plugins/test_sharded_plugin.py       | 41 ++++++++++++++++++-
 9 files changed, 163 insertions(+), 23 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index d86f5ac2e9..3913552145 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -285,6 +285,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed a bug causing logging with `log_gpu_memory='min_max'` not working ([#9013](https://github.com/PyTorchLightning/pytorch-lightning/pull/9013))
 
+- Fixed `DDPPlugin`, `DDPSpawnPlugin`, `DDPShardedPlugin` and `DDPSpawnShardedPlugin` to no longer wrap the `LightningModule` in a data-parallel module when not fitting ([#9096](https://github.com/PyTorchLightning/pytorch-lightning/pull/9096))
+
+
 ## [1.4.3] - 2021-08-17
 
 - Fixed plateau scheduler stepping on incomplete epoch ([#8861](https://github.com/PyTorchLightning/pytorch-lightning/pull/8861))
diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py
index 6d96a443e3..2396670a49 100644
--- a/pytorch_lightning/plugins/training_type/ddp.py
+++ b/pytorch_lightning/plugins/training_type/ddp.py
@@ -56,6 +56,7 @@ from pytorch_lightning.utilities.distributed import (
 )
 from pytorch_lightning.utilities.exceptions import DeadlockDetectedException, MisconfigurationException
 from pytorch_lightning.utilities.seed import reset_seed
+from pytorch_lightning.utilities.types import STEP_OUTPUT
 
 if _TORCH_GREATER_EQUAL_1_10:
     from torch.distributed.optim import DistributedOptimizer, PostLocalSGDOptimizer, ZeroRedundancyOptimizer
@@ -361,7 +362,7 @@ class DDPPlugin(ParallelPlugin):
         trainer.optimizers = optimizers
         trainer.convert_to_lightning_optimizers()
 
-    def configure_ddp(self):
+    def configure_ddp(self) -> None:
         self.pre_configure_ddp()
         self._model = DistributedDataParallel(
             LightningDistributedModule(self.model), device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs
@@ -380,7 +381,10 @@ class DDPPlugin(ParallelPlugin):
         if self.sync_batchnorm:
             self.model = self.configure_sync_batchnorm(self.model)
 
-        self.configure_ddp()
+        # skip wrapping the model if we are not fitting as no gradients need to be exchanged
+        trainer_fn = self.lightning_module.trainer.state.fn
+        if trainer_fn == TrainerFn.FITTING:
+            self.configure_ddp()
 
         # share ddp pids to all processes
         self._share_information_to_prevent_deadlock()
@@ -424,17 +428,22 @@ class DDPPlugin(ParallelPlugin):
         tensor = sync_ddp_if_available(tensor, group, reduce_op=reduce_op)
         return tensor
 
-    def training_step(self, *args, **kwargs):
+    def training_step(self, *args, **kwargs) -> Optional[Any]:
         return self.model(*args, **kwargs)
 
-    def validation_step(self, *args, **kwargs):
-        return self.model(*args, **kwargs)
+    def validation_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
+        if isinstance(self.model, DistributedDataParallel):
+            # used when calling `trainer.fit`
+            return self.model(*args, **kwargs)
+        else:
+            # used when calling `trainer.validate`
+            return self.lightning_module.validation_step(*args, **kwargs)
 
-    def test_step(self, *args, **kwargs):
-        return self.model(*args, **kwargs)
+    def test_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
+        return self.lightning_module.test_step(*args, **kwargs)
 
-    def predict_step(self, *args, **kwargs):
-        return self.model(*args, **kwargs)
+    def predict_step(self, *args, **kwargs) -> Any:
+        return self.lightning_module.predict_step(*args, **kwargs)
 
     def post_training_step(self):
         if not self.lightning_module.automatic_optimization:
diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py
index c31a908902..a45e70adff 100644
--- a/pytorch_lightning/plugins/training_type/ddp_spawn.py
+++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -46,6 +46,7 @@ from pytorch_lightning.utilities.distributed import (
     sync_ddp_if_available,
 )
 from pytorch_lightning.utilities.seed import reset_seed
+from pytorch_lightning.utilities.types import STEP_OUTPUT
 
 if _TORCH_GREATER_EQUAL_1_8:
     from pytorch_lightning.utilities.distributed import register_ddp_comm_hook
@@ -201,7 +202,10 @@ class DDPSpawnPlugin(ParallelPlugin):
         if self.sync_batchnorm:
             self.model = self.configure_sync_batchnorm(self.model)
 
-        self.configure_ddp()
+        # skip wrapping the model if we are not fitting as no gradients need to be exchanged
+        trainer_fn = self.lightning_module.trainer.state.fn
+        if trainer_fn == TrainerFn.FITTING:
+            self.configure_ddp()
 
         self.barrier()
 
@@ -254,7 +258,7 @@ class DDPSpawnPlugin(ParallelPlugin):
             ddp_comm_wrapper=self._ddp_comm_wrapper,
         )
 
-    def configure_ddp(self):
+    def configure_ddp(self) -> None:
         self.pre_configure_ddp()
         self._model = DistributedDataParallel(
             LightningDistributedModule(self.model), device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs
@@ -340,17 +344,22 @@ class DDPSpawnPlugin(ParallelPlugin):
         tensor = sync_ddp_if_available(tensor, group, reduce_op=reduce_op)
         return tensor
 
-    def training_step(self, *args, **kwargs):
+    def training_step(self, *args, **kwargs) -> Optional[Any]:
         return self.model(*args, **kwargs)
 
-    def validation_step(self, *args, **kwargs):
-        return self.model(*args, **kwargs)
+    def validation_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
+        if isinstance(self.model, DistributedDataParallel):
+            # used when calling `trainer.fit`
+            return self.model(*args, **kwargs)
+        else:
+            # used when calling `trainer.validate`
+            return self.lightning_module.validation_step(*args, **kwargs)
 
-    def test_step(self, *args, **kwargs):
-        return self.model(*args, **kwargs)
+    def test_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
+        return self.lightning_module.test_step(*args, **kwargs)
 
-    def predict_step(self, *args, **kwargs):
-        return self.model(*args, **kwargs)
+    def predict_step(self, *args, **kwargs) -> Any:
+        return self.lightning_module.predict_step(*args, **kwargs)
 
     def post_training_step(self):
         if not self.lightning_module.automatic_optimization:
diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py
index 89851fafef..5fa8739de7 100644
--- a/pytorch_lightning/plugins/training_type/deepspeed.py
+++ b/pytorch_lightning/plugins/training_type/deepspeed.py
@@ -818,3 +818,12 @@ class DeepSpeedPlugin(DDPPlugin):
     @checkpoint_io.setter
     def checkpoint_io(self, plugin: CheckpointIO) -> None:
         raise MisconfigurationException("DeepSpeed currently does not support custom checkpoint plugins.")
+
+    def validation_step(self, *args, **kwargs):
+        return self.model(*args, **kwargs)
+
+    def test_step(self, *args, **kwargs):
+        return self.model(*args, **kwargs)
+
+    def predict_step(self, *args, **kwargs):
+        return self.model(*args, **kwargs)
diff --git a/pytorch_lightning/plugins/training_type/sharded.py b/pytorch_lightning/plugins/training_type/sharded.py
index a9208233db..8d05432810 100644
--- a/pytorch_lightning/plugins/training_type/sharded.py
+++ b/pytorch_lightning/plugins/training_type/sharded.py
@@ -34,7 +34,7 @@ class DDPShardedPlugin(DDPPlugin):
 
     _REDUCE_BUFFER_SIZE_DEFAULT = 2 ** 23  # 8M
 
-    def configure_ddp(self):
+    def configure_ddp(self) -> None:
         self._wrap_optimizers()
         self._model = ShardedDataParallel(
             LightningShardedDataParallel(self.model),
diff --git a/pytorch_lightning/plugins/training_type/sharded_spawn.py b/pytorch_lightning/plugins/training_type/sharded_spawn.py
index 4a4900b03b..ce96d43a9b 100644
--- a/pytorch_lightning/plugins/training_type/sharded_spawn.py
+++ b/pytorch_lightning/plugins/training_type/sharded_spawn.py
@@ -33,7 +33,7 @@ if _FAIRSCALE_AVAILABLE:
 class DDPSpawnShardedPlugin(DDPSpawnPlugin):
     """Optimizer sharded training provided by FairScale."""
 
-    def configure_ddp(self):
+    def configure_ddp(self) -> None:
         self._wrap_optimizers()
         self._model = ShardedDataParallel(
             LightningShardedDataParallel(self.model), sharded_optimizer=self.lightning_module.trainer.optimizers
diff --git a/tests/plugins/test_ddp_plugin.py b/tests/plugins/test_ddp_plugin.py
index 60ec193085..bd13275e9e 100644
--- a/tests/plugins/test_ddp_plugin.py
+++ b/tests/plugins/test_ddp_plugin.py
@@ -18,9 +18,10 @@ import pytest
 import torch
 from torch.nn.parallel import DistributedDataParallel
 
-from pytorch_lightning import Trainer
+from pytorch_lightning import LightningModule, Trainer
 from pytorch_lightning.plugins import DDPPlugin
 from pytorch_lightning.plugins.environments import LightningEnvironment
+from pytorch_lightning.trainer.states import TrainerFn
 from tests.helpers.boring_model import BoringModel
 from tests.helpers.runif import RunIf
 
@@ -94,3 +95,37 @@ def test_incorrect_ddp_script_spawning(tmpdir):
         RuntimeError, match="Lightning attempted to launch new distributed processes with `local_rank > 0`."
     ):
         trainer.fit(model)
+
+
+@RunIf(skip_windows=True)
+def test_ddp_configure_ddp():
+    """Test that DDPPlugin wraps the model in DistributedDataParallel only when fitting."""
+    model = BoringModel()
+    ddp_plugin = DDPPlugin()
+    trainer = Trainer(
+        max_epochs=1,
+        plugins=[ddp_plugin],
+    )
+    # the model is wrapped when the trainer function is fitting
+    trainer.state.fn = TrainerFn.FITTING
+    trainer.accelerator.connect(model)
+    trainer.accelerator.setup_environment()
+    trainer.accelerator.setup(trainer)
+    trainer.lightning_module.trainer = trainer
+    assert isinstance(trainer.model, LightningModule)
+    trainer._pre_dispatch()
+    # DDPPlugin.configure_ddp() has wrapped the model in DistributedDataParallel
+    assert isinstance(trainer.model, DistributedDataParallel)
+
+    trainer = Trainer(
+        max_epochs=1,
+        plugins=[ddp_plugin],
+    )
+    # the model is not wrapped when the trainer function is not fitting
+    trainer.accelerator.connect(model)
+    trainer.accelerator.setup_environment()
+    trainer.accelerator.setup(trainer)
+    trainer.lightning_module.trainer = trainer
+    trainer._pre_dispatch()
+    # DDPPlugin.configure_ddp() was skipped, so the model is still a LightningModule
+    assert isinstance(trainer.model, LightningModule)
diff --git a/tests/plugins/test_ddp_spawn_plugin.py b/tests/plugins/test_ddp_spawn_plugin.py
index 1ab94446c8..2d987b0788 100644
--- a/tests/plugins/test_ddp_spawn_plugin.py
+++ b/tests/plugins/test_ddp_spawn_plugin.py
@@ -12,9 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import torch
+from torch.nn.parallel.distributed import DistributedDataParallel
 
-from pytorch_lightning import Trainer
+from pytorch_lightning import LightningModule, Trainer
 from pytorch_lightning.plugins import DDPSpawnPlugin
+from pytorch_lightning.trainer.states import TrainerFn
 from tests.helpers.boring_model import BoringDataModule, BoringModel
 from tests.helpers.runif import RunIf
 
@@ -77,3 +79,37 @@ def test_ddp_spawn_extra_parameters(tmpdir):
     trainer.fit(model, datamodule=dm)
     assert trainer.callback_metrics[val_name] == torch.tensor(val)
     assert model.test_val == "test_val"
+
+
+class BoringModelDDP(BoringModel):
+    def on_train_start(self) -> None:
+        """Check that the model is wrapped in DistributedDataParallel during the training stage."""
+        assert isinstance(self.trainer.model, DistributedDataParallel)
+
+    def on_validation_start(self) -> None:
+        """Check that the model is wrapped in DistributedDataParallel when fitting, LightningModule otherwise."""
+        if self.trainer.state.fn == TrainerFn.FITTING:
+            assert isinstance(self.trainer.model, DistributedDataParallel)
+        else:
+            assert isinstance(self.trainer.model, LightningModule)
+
+    def on_test_start(self) -> None:
+        """Check that the model remains a LightningModule during the test stage."""
+        assert isinstance(self.trainer.model, LightningModule)
+
+    def on_predict_start(self) -> None:
+        """Check that the model remains a LightningModule during the prediction stage."""
+        assert isinstance(self.trainer.model, LightningModule)
+
+
+@RunIf(skip_windows=True)
+def test_ddp_spawn_configure_ddp(tmpdir):
+    """Test that DDPSpawnPlugin wraps the model in DistributedDataParallel only when fitting."""
+    trainer = Trainer(default_root_dir=tmpdir, num_processes=2, accelerator="ddp_spawn", fast_dev_run=True)
+
+    model = BoringModelDDP()
+
+    trainer.fit(model)
+    trainer.validate(model, dataloaders=model.val_dataloader())
+    trainer.test(model, dataloaders=model.test_dataloader())
+    trainer.predict(model, dataloaders=model.predict_dataloader())
diff --git a/tests/plugins/test_sharded_plugin.py b/tests/plugins/test_sharded_plugin.py
index 82e899a6f4..3bcdc357fc 100644
--- a/tests/plugins/test_sharded_plugin.py
+++ b/tests/plugins/test_sharded_plugin.py
@@ -4,13 +4,18 @@ from unittest import mock
 import pytest
 import torch
 
-from pytorch_lightning import Trainer
+from pytorch_lightning import LightningModule, Trainer
 from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.plugins import DDPShardedPlugin, DDPSpawnShardedPlugin
+from pytorch_lightning.trainer.states import TrainerFn
+from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers.boring_model import BoringModel
 from tests.helpers.runif import RunIf
 
+if _FAIRSCALE_AVAILABLE:
+    from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel
+
 
 @pytest.mark.parametrize("clip_val", [0, 10])
 @RunIf(min_gpus=1, skip_windows=True, amp_native=True, fairscale=True)
@@ -249,3 +254,37 @@ def test_ddp_sharded_plugin_manual_optimization(tmpdir):
     model = ManualBoringModel()
     trainer = Trainer(default_root_dir=tmpdir, accelerator="ddp_sharded", fast_dev_run=2, gpus=2)
     trainer.fit(model)
+
+
+class BoringModelSharded(BoringModel):
+    def on_train_start(self) -> None:
+        """Check that the model is wrapped in ShardedDataParallel during the training stage."""
+        assert isinstance(self.trainer.model, ShardedDataParallel)
+
+    def on_test_start(self) -> None:
+        """Check that the model remains a LightningModule during the test stage."""
+        assert isinstance(self.trainer.model, LightningModule)
+
+    def on_validation_start(self) -> None:
+        """Check that the model is wrapped in ShardedDataParallel when fitting, LightningModule otherwise."""
+        if self.trainer.state.fn == TrainerFn.FITTING:
+            assert isinstance(self.trainer.model, ShardedDataParallel)
+        else:
+            assert isinstance(self.trainer.model, LightningModule)
+
+    def on_predict_start(self) -> None:
+        """Check that the model remains a LightningModule during the prediction stage."""
+        assert isinstance(self.trainer.model, LightningModule)
+
+
+@RunIf(skip_windows=True, fairscale=True)
+def test_configure_ddp(tmpdir):
+    """Test that DDPShardedPlugin wraps the model in ShardedDataParallel only when fitting."""
+    trainer = Trainer(default_root_dir=tmpdir, accelerator="ddp_sharded", fast_dev_run=True)
+
+    model = BoringModelSharded()
+
+    trainer.fit(model)
+    trainer.test(model, dataloaders=model.test_dataloader())
+    trainer.validate(model, dataloaders=model.val_dataloader())
+    trainer.predict(model, dataloaders=model.predict_dataloader())
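
Note (outside the patch above): a minimal sketch of the dispatch pattern this diff introduces in `DDPPlugin.validation_step`, rewritten as a standalone helper for illustration. The `run_validation_step` name and its signature are hypothetical, not PyTorch Lightning API; only the isinstance-based routing mirrors the patched code.

# Illustrative sketch only; mirrors the patched validation_step dispatch.
from torch import nn
from torch.nn.parallel import DistributedDataParallel


def run_validation_step(model: nn.Module, lightning_module: nn.Module, *args, **kwargs):
    if isinstance(model, DistributedDataParallel):
        # `trainer.fit` path: the module was wrapped, so route through the DDP forward
        return model(*args, **kwargs)
    # `trainer.validate` path: no wrapping happened, call the module's hook directly
    return lightning_module.validation_step(*args, **kwargs)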