Fix ShardedDataParallel has no attribute require_backward_grad_sync (#6915)
Co-authored-by: Kaushik B <45285388+kaushikb11@users.noreply.github.com>
parent 20ff50caa6
commit fe0d08899e

CHANGELOG.md
@@ -237,6 +237,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `--gpus` default for parser returned by `Trainer.add_argparse_args` ([#6898](https://github.com/PyTorchLightning/pytorch-lightning/pull/6898))

- Fixed `AttributeError` for `require_backward_grad_sync` when running manual optimization with sharded plugin ([#6915](https://github.com/PyTorchLightning/pytorch-lightning/pull/6915))


## [1.2.7] - 2021-04-06

### Fixed
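Note on the failure mode, for context: around manual-optimization backward calls the DDP plugin reads and writes a `require_backward_grad_sync` flag on the wrapped model. Torch's `DistributedDataParallel` exposes that flag, but fairscale's `ShardedDataParallel` does not, hence the `AttributeError` this commit fixes. Below is a minimal, self-contained sketch of that mechanism; `FakeShardedModel` is an illustrative stand-in, not a Lightning or fairscale class.

import torch.nn as nn


class FakeShardedModel(nn.Module):
    """Stand-in for ShardedDataParallel: unlike DistributedDataParallel,
    it defines no `require_backward_grad_sync` attribute."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(2, 2)


model = FakeShardedModel()

try:
    # The kind of flag read performed around manual-optimization backward calls.
    if model.require_backward_grad_sync:
        print("would prepare gradient sync here")
except AttributeError as err:
    print(f"before the fix: {err}")

# The commit sets the flag explicitly on the wrapper (and makes the hooks that read it
# no-ops), so the lookup above can no longer fail.
setattr(model, "require_backward_grad_sync", False)
print(model.require_backward_grad_sync)  # -> False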
pytorch_lightning/plugins/training_type/sharded.py
@@ -13,6 +13,9 @@
# limitations under the License.
from typing import Optional

import torch
from torch.optim import Optimizer

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.optimizer import is_lightning_optimizer
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
@@ -33,6 +36,7 @@ class DDPShardedPlugin(DDPPlugin):
        self._model = ShardedDataParallel(
            LightningShardedDataParallel(self.model), sharded_optimizer=self.lightning_module.trainer.optimizers
        )
        setattr(self._model, "require_backward_grad_sync", False)

    def _reinit_optimizers_with_oss(self):
        optimizers = self.lightning_module.trainer.optimizers
@@ -70,3 +74,9 @@ class DDPShardedPlugin(DDPPlugin):
    @property
    def lightning_module(self) -> LightningModule:
        return unwrap_lightning_module_sharded(self._model)

    def pre_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, optimizer: Optimizer, opt_idx: int):
        pass

    def post_training_step(self):
        pass
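Why `pre_backward` and `post_training_step` can simply `pass` here: in the parent `DDPPlugin` these hooks touch `require_backward_grad_sync` on the wrapped model when automatic optimization is off, which is exactly the access that failed on `ShardedDataParallel`; the fairscale wrapper handles its own gradient reduction, so the sharded plugins can skip the hooks entirely. A minimal sketch of that override pattern follows; the class names are illustrative, not the actual Lightning plugin classes.

class DDPLikePlugin:
    """Rough analogue of the base plugin: its hooks consult the sync flag."""

    def __init__(self, model):
        self.model = model

    def pre_backward(self, loss):
        # Assumed shape of the base hook: consult the flag before preparing backward sync.
        if self.model.require_backward_grad_sync:
            print("preparing backward gradient sync")

    def post_training_step(self):
        self.model.require_backward_grad_sync = True


class ShardedLikePlugin(DDPLikePlugin):
    """Mirrors the fix: both hooks become no-ops for the sharded wrapper."""

    def pre_backward(self, loss):
        pass

    def post_training_step(self):
        pass


class ShardedWrapperStub:
    """Stand-in wrapper with no `require_backward_grad_sync` attribute."""


plugin = ShardedLikePlugin(ShardedWrapperStub())
plugin.pre_backward(loss=None)  # no AttributeError: the flag is never read
plugin.post_training_step()     # likewise a no-op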
pytorch_lightning/plugins/training_type/sharded_spawn.py
@@ -13,6 +13,9 @@
# limitations under the License.
from typing import Optional

import torch
from torch.optim import Optimizer

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
from pytorch_lightning.trainer.states import TrainerState
@@ -32,6 +35,7 @@ class DDPSpawnShardedPlugin(DDPSpawnPlugin):
        self._model = ShardedDataParallel(
            LightningShardedDataParallel(self.model), sharded_optimizer=self.lightning_module.trainer.optimizers
        )
        setattr(self._model, "require_backward_grad_sync", False)

    def _reinit_optimizers_with_oss(self):
        optimizers = self.lightning_module.trainer.optimizers
@@ -65,3 +69,9 @@ class DDPSpawnShardedPlugin(DDPSpawnPlugin):
    @property
    def lightning_module(self) -> LightningModule:
        return unwrap_lightning_module_sharded(self._model)

    def pre_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, optimizer: Optimizer, opt_idx: int):
        pass

    def post_training_step(self):
        pass
tests/plugins/test_sharded_plugin.py
@@ -278,3 +278,32 @@ def test_ddp_sharded_plugin_test_multigpu(tmpdir, trainer_kwargs):

    trainer.validate(model)
    trainer.test(model)


class ManualBoringModel(BoringModel):

    def __init__(self):
        super().__init__()
        self.automatic_optimization = False

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        opt.zero_grad()
        output = self(batch)
        loss = self.loss(batch, output)
        self.manual_backward(loss)
        opt.step()
        return {"loss": loss}


@RunIf(skip_windows=True, special=True, fairscale=True, min_gpus=2)
@pytest.mark.parametrize("accelerator", ["ddp_sharded", "ddp_sharded_spawn"])
def test_ddp_sharded_plugin_manual_optimization(tmpdir, accelerator):
    model = ManualBoringModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        accelerator=accelerator,
        fast_dev_run=2,
        gpus=2,
    )
    trainer.fit(model)
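For completeness, a stand-alone script exercising the same combination the new test covers (manual optimization with the sharded accelerator) could look roughly like the sketch below. It assumes fairscale is installed and two GPUs are available; the model, optimizer, and data are illustrative and not part of this commit, while the Trainer arguments follow the 1.2.x API used in the test.

import torch
from torch.utils.data import DataLoader, TensorDataset

from pytorch_lightning import LightningModule, Trainer


class ManualModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False  # manual optimization, as in the test
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        opt.zero_grad()
        x = batch[0]
        loss = self.layer(x).sum()
        self.manual_backward(loss)
        opt.step()
        return {"loss": loss}

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        return DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)


if __name__ == "__main__":
    trainer = Trainer(accelerator="ddp_sharded", gpus=2, fast_dev_run=2)
    trainer.fit(ManualModel())  # raised AttributeError before this fix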