From fe0d08899eba94d275ff42253f495d9e70d86f89 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Sat, 10 Apr 2021 18:14:37 +0200
Subject: [PATCH] Fix ShardedDataParallel has no attribute require_backward_grad_sync (#6915)

Co-authored-by: Kaushik B <45285388+kaushikb11@users.noreply.github.com>
---
 CHANGELOG.md                                  |  3 ++
 .../plugins/training_type/sharded.py          | 10 +++++++
 .../plugins/training_type/sharded_spawn.py    | 10 +++++++
 tests/plugins/test_sharded_plugin.py          | 29 +++++++++++++++++++
 4 files changed, 52 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 8e4ca17b59..0f96664db7 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -237,6 +237,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `--gpus` default for parser returned by `Trainer.add_argparse_args` ([#6898](https://github.com/PyTorchLightning/pytorch-lightning/pull/6898))
 
 
+- Fixed `AttributeError` for `require_backward_grad_sync` when running manual optimization with sharded plugin ([#6915](https://github.com/PyTorchLightning/pytorch-lightning/pull/6915))
+
+
 ## [1.2.7] - 2021-04-06
 
 ### Fixed
diff --git a/pytorch_lightning/plugins/training_type/sharded.py b/pytorch_lightning/plugins/training_type/sharded.py
index 7536ef9b1d..eaa9331d28 100644
--- a/pytorch_lightning/plugins/training_type/sharded.py
+++ b/pytorch_lightning/plugins/training_type/sharded.py
@@ -13,6 +13,9 @@
 # limitations under the License.
 from typing import Optional
 
+import torch
+from torch.optim import Optimizer
+
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.core.optimizer import is_lightning_optimizer
 from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
@@ -33,6 +36,7 @@ class DDPShardedPlugin(DDPPlugin):
         self._model = ShardedDataParallel(
             LightningShardedDataParallel(self.model), sharded_optimizer=self.lightning_module.trainer.optimizers
         )
+        setattr(self._model, "require_backward_grad_sync", False)
 
     def _reinit_optimizers_with_oss(self):
         optimizers = self.lightning_module.trainer.optimizers
@@ -70,3 +74,9 @@ class DDPShardedPlugin(DDPPlugin):
     @property
     def lightning_module(self) -> LightningModule:
         return unwrap_lightning_module_sharded(self._model)
+
+    def pre_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, optimizer: Optimizer, opt_idx: int):
+        pass
+
+    def post_training_step(self):
+        pass
diff --git a/pytorch_lightning/plugins/training_type/sharded_spawn.py b/pytorch_lightning/plugins/training_type/sharded_spawn.py
index 7aadf797e1..dc1e5e8384 100644
--- a/pytorch_lightning/plugins/training_type/sharded_spawn.py
+++ b/pytorch_lightning/plugins/training_type/sharded_spawn.py
@@ -13,6 +13,9 @@
 # limitations under the License.
 from typing import Optional
 
+import torch
+from torch.optim import Optimizer
+
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
 from pytorch_lightning.trainer.states import TrainerState
@@ -32,6 +35,7 @@ class DDPSpawnShardedPlugin(DDPSpawnPlugin):
         self._model = ShardedDataParallel(
             LightningShardedDataParallel(self.model), sharded_optimizer=self.lightning_module.trainer.optimizers
         )
+        setattr(self._model, "require_backward_grad_sync", False)
 
     def _reinit_optimizers_with_oss(self):
         optimizers = self.lightning_module.trainer.optimizers
@@ -65,3 +69,9 @@ class DDPSpawnShardedPlugin(DDPSpawnPlugin):
     @property
     def lightning_module(self) -> LightningModule:
         return unwrap_lightning_module_sharded(self._model)
+
+    def pre_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, optimizer: Optimizer, opt_idx: int):
+        pass
+
+    def post_training_step(self):
+        pass
diff --git a/tests/plugins/test_sharded_plugin.py b/tests/plugins/test_sharded_plugin.py
index 655e12f046..7ab49e6826 100644
--- a/tests/plugins/test_sharded_plugin.py
+++ b/tests/plugins/test_sharded_plugin.py
@@ -278,3 +278,32 @@ def test_ddp_sharded_plugin_test_multigpu(tmpdir, trainer_kwargs):
 
     trainer.validate(model)
     trainer.test(model)
+
+
+class ManualBoringModel(BoringModel):
+
+    def __init__(self):
+        super().__init__()
+        self.automatic_optimization = False
+
+    def training_step(self, batch, batch_idx):
+        opt = self.optimizers()
+        opt.zero_grad()
+        output = self(batch)
+        loss = self.loss(batch, output)
+        self.manual_backward(loss)
+        opt.step()
+        return {"loss": loss}
+
+
+@RunIf(skip_windows=True, special=True, fairscale=True, min_gpus=2)
+@pytest.mark.parametrize("accelerator", ["ddp_sharded", "ddp_sharded_spawn"])
+def test_ddp_sharded_plugin_manual_optimization(tmpdir, accelerator):
+    model = ManualBoringModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        accelerator=accelerator,
+        fast_dev_run=2,
+        gpus=2,
+    )
+    trainer.fit(model)
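
Note: for reference, below is a minimal, stand-alone sketch of the user scenario the new test covers: manual optimization with the sharded plugin, which previously failed with "ShardedDataParallel has no attribute require_backward_grad_sync". The file name, RandomDataset, and ManualLinearModel are illustrative assumptions only and are not part of this patch; running it requires fairscale and at least two GPUs.

# manual_sharded_example.py -- illustrative sketch only, not part of the patch.
# Assumes pytorch-lightning of this era (~1.2/1.3), fairscale installed, and >= 2 GPUs.
import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer


class RandomDataset(Dataset):
    # Tiny random dataset, just enough to drive a couple of training steps.

    def __init__(self, size: int = 32, length: int = 64):
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)


class ManualLinearModel(LightningModule):
    # Manual optimization: the user drives zero_grad/backward/step themselves.
    # Before this fix, manual_backward raised an AttributeError for
    # require_backward_grad_sync when the model was wrapped by ShardedDataParallel.

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        self.automatic_optimization = False

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        opt.zero_grad()
        loss = self(batch).sum()
        self.manual_backward(loss)
        opt.step()
        return {"loss": loss}

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)

    def train_dataloader(self):
        return DataLoader(RandomDataset(), batch_size=8)


if __name__ == "__main__":
    model = ManualLinearModel()
    # "ddp_sharded_spawn" exercises the same code path; fast_dev_run keeps the run short.
    trainer = Trainer(accelerator="ddp_sharded", gpus=2, fast_dev_run=2)
    trainer.fit(model)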