Fix ShardedDataParallel has no attribute require_backward_grad_sync (#6915)

Co-authored-by: Kaushik B <45285388+kaushikb11@users.noreply.github.com>
Adrian Wälchli 2021-04-10 18:14:37 +02:00 committed by GitHub
parent 20ff50caa6
commit fe0d08899e
4 changed files with 52 additions and 0 deletions

@@ -237,6 +237,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `--gpus` default for parser returned by `Trainer.add_argparse_args` ([#6898](https://github.com/PyTorchLightning/pytorch-lightning/pull/6898))
- Fixed `AttributeError` for `require_backward_grad_sync` when running manual optimization with sharded plugin ([#6915](https://github.com/PyTorchLightning/pytorch-lightning/pull/6915))

## [1.2.7] - 2021-04-06

### Fixed

@@ -13,6 +13,9 @@
# limitations under the License.
from typing import Optional
import torch
from torch.optim import Optimizer
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.optimizer import is_lightning_optimizer
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
@@ -33,6 +36,7 @@ class DDPShardedPlugin(DDPPlugin):
        self._model = ShardedDataParallel(
            LightningShardedDataParallel(self.model), sharded_optimizer=self.lightning_module.trainer.optimizers
        )
        setattr(self._model, "require_backward_grad_sync", False)

    def _reinit_optimizers_with_oss(self):
        optimizers = self.lightning_module.trainer.optimizers
@@ -70,3 +74,9 @@ class DDPShardedPlugin(DDPPlugin):
    @property
    def lightning_module(self) -> LightningModule:
        return unwrap_lightning_module_sharded(self._model)

    def pre_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, optimizer: Optimizer, opt_idx: int):
        pass

    def post_training_step(self):
        pass
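
Why the added `setattr` matters, in a minimal hypothetical sketch (the names `FakeShardedWrapper` and `pre_backward_like_hook` below are illustrative stand-ins, not Lightning or fairscale API): torch's DistributedDataParallel exposes a `require_backward_grad_sync` flag that DDP-oriented code may read around a manual backward call, while fairscale's ShardedDataParallel wrapper defines no such attribute, which is what raised the `AttributeError` this commit fixes.

# Hypothetical stand-ins to illustrate the failure mode; not real Lightning/fairscale classes.
import torch
import torch.nn as nn


class FakeShardedWrapper(nn.Module):
    """Mimics a sharded wrapper that, unlike DDP, defines no ``require_backward_grad_sync``."""

    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)


def pre_backward_like_hook(model: nn.Module, loss: torch.Tensor) -> None:
    # Roughly what a DDP-oriented hook does: consult the wrapper's sync flag before backward.
    if model.require_backward_grad_sync:  # raises AttributeError on the bare sharded wrapper
        pass  # a real DDP hook would prepare the gradient all-reduce here
    loss.backward()


wrapped = FakeShardedWrapper(nn.Linear(2, 1))
setattr(wrapped, "require_backward_grad_sync", False)  # the workaround applied in this commit
pre_backward_like_hook(wrapped, wrapped(torch.randn(4, 2)).sum())

With the attribute pinned to False (and the backward hooks overridden to no-ops in the plugin above), the manual-optimization path never trips over the missing DDP flag.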

@@ -13,6 +13,9 @@
# limitations under the License.
from typing import Optional
import torch
from torch.optim import Optimizer
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
from pytorch_lightning.trainer.states import TrainerState
@@ -32,6 +35,7 @@ class DDPSpawnShardedPlugin(DDPSpawnPlugin):
        self._model = ShardedDataParallel(
            LightningShardedDataParallel(self.model), sharded_optimizer=self.lightning_module.trainer.optimizers
        )
        setattr(self._model, "require_backward_grad_sync", False)

    def _reinit_optimizers_with_oss(self):
        optimizers = self.lightning_module.trainer.optimizers
@@ -65,3 +69,9 @@ class DDPSpawnShardedPlugin(DDPSpawnPlugin):
    @property
    def lightning_module(self) -> LightningModule:
        return unwrap_lightning_module_sharded(self._model)

    def pre_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, optimizer: Optimizer, opt_idx: int):
        pass

    def post_training_step(self):
        pass
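
The same two no-op overrides appear in both `DDPShardedPlugin` and `DDPSpawnShardedPlugin` because their DDP base plugins consult the DDP-only sync flag inside these hooks during manual optimization. A rough sketch of that override pattern, assuming made-up class names (`BaseSyncPlugin`, `ShardedSyncPlugin`) rather than the actual Lightning plugin hierarchy:

# Illustrative only: these are not Lightning classes.
import torch


class BaseSyncPlugin:
    def __init__(self, model) -> None:
        self.model = model

    def pre_backward(self, closure_loss: torch.Tensor) -> None:
        # A DDP-style plugin reads the wrapper's sync flag right before backward runs.
        if self.model.require_backward_grad_sync:
            pass  # a real DDP plugin would prepare the gradient all-reduce here

    def post_training_step(self) -> None:
        # ... and turns synchronization back on after a manual-optimization step.
        self.model.require_backward_grad_sync = True


class ShardedSyncPlugin(BaseSyncPlugin):
    # The sharded wrapper handles gradient reduction itself and defines no such flag,
    # so both hooks are overridden to do nothing, mirroring the plugin code above.
    def pre_backward(self, closure_loss: torch.Tensor) -> None:
        pass

    def post_training_step(self) -> None:
        pass


if __name__ == "__main__":
    class FlaglessWrapper:  # stands in for a sharded wrapper without the DDP flag
        pass

    ShardedSyncPlugin(FlaglessWrapper()).pre_backward(torch.tensor(1.0))   # fine: no-op
    # BaseSyncPlugin(FlaglessWrapper()).pre_backward(torch.tensor(1.0))    # would raise AttributeError

Keeping the overrides as plain `pass` leaves gradient handling entirely to the sharded wrapper.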

@@ -278,3 +278,32 @@ def test_ddp_sharded_plugin_test_multigpu(tmpdir, trainer_kwargs):
    trainer.validate(model)
    trainer.test(model)


class ManualBoringModel(BoringModel):

    def __init__(self):
        super().__init__()
        self.automatic_optimization = False

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        opt.zero_grad()
        output = self(batch)
        loss = self.loss(batch, output)
        self.manual_backward(loss)
        opt.step()
        return {"loss": loss}


@RunIf(skip_windows=True, special=True, fairscale=True, min_gpus=2)
@pytest.mark.parametrize("accelerator", ["ddp_sharded", "ddp_sharded_spawn"])
def test_ddp_sharded_plugin_manual_optimization(tmpdir, accelerator):
    model = ManualBoringModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        accelerator=accelerator,
        fast_dev_run=2,
        gpus=2,
    )
    trainer.fit(model)