[bug/feat] Support parameters_to_ignore in DDP (#7239)

* update

* update

* update

* update on comments

* update
Author: thomas chaton, 2021-04-27 18:49:32 +01:00 (committed by GitHub)
Commit: 5a113a2f05 (parent: 7fe8d18477)
3 changed files with 41 additions and 0 deletions

CHANGELOG.md

@@ -367,6 +367,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed metrics not being properly logged with `precision=16` and `manual_optimization` ([#7228](https://github.com/PyTorchLightning/pytorch-lightning/pull/7228))
- Fixed `parameters_to_ignore` not being properly set on the DDP wrapper ([#7239](https://github.com/PyTorchLightning/pytorch-lightning/pull/7239))

## [1.2.7] - 2021-04-06

### Fixed

pytorch_lightning/overrides/base.py

@@ -36,6 +36,9 @@ class _LightningModuleWrapperBase(DeviceDtypeModuleMixin, torch.nn.Module):
        super().__init__()
        self.module = pl_module

        # set the parameters_to_ignore from the LightningModule.
        self._ddp_params_and_buffers_to_ignore = getattr(pl_module, "_ddp_params_and_buffers_to_ignore", [])

    def forward(self, *inputs, **kwargs):
        trainer = self.module.trainer
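Note on the mechanics: torch's `DistributedDataParallel` looks for a `_ddp_params_and_buffers_to_ignore` attribute on the module it wraps and excludes the named parameters/buffers from broadcast and gradient synchronization. Because Lightning hands DDP the wrapper rather than the user's `LightningModule`, the attribute must be copied onto the wrapper, which is what the added line does. A minimal plain-torch sketch of the same mechanism, using `_set_params_and_buffers_to_ignore_for_model` (a private torch helper, available in recent torch versions, that simply sets this attribute):

```python
import torch
from torch.nn.parallel import DistributedDataParallel

module = torch.nn.Linear(4, 4)

# equivalent to: module._ddp_params_and_buffers_to_ignore = ["bias"]
DistributedDataParallel._set_params_and_buffers_to_ignore_for_model(module, ["bias"])

# With an initialized process group, DDP reads the attribute at construction
# time and skips "bias" during broadcast and gradient reduction:
#   ddp = DistributedDataParallel(module)
#   assert ddp.parameters_to_ignore == ["bias"]
```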

tests/accelerators/test_ddp.py

@@ -18,8 +18,11 @@ from unittest.mock import patch

import pytest
import torch
from torch.nn.parallel.distributed import DistributedDataParallel

import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Callback
from tests.accelerators import ddp_model
from tests.helpers.boring_model import BoringModel
from tests.helpers.runif import RunIf
@@ -117,3 +120,35 @@ def test_ddp_torch_dist_is_available_in_setup(mock_set_device, mock_is_available
    )
    with pytest.raises(SystemExit):
        trainer.fit(model)
@RunIf(min_gpus=2, min_torch="1.8.1", special=True)
def test_ddp_wrapper(tmpdir):
    """
    Test that parameters to ignore are carried over for DDP.
    """

    class WeirdModule(torch.nn.Module):

        def _save_to_state_dict(self, destination, prefix, keep_vars):
            return {"something": "something"}

    class CustomModel(BoringModel):

        def __init__(self):
            super().__init__()
            self.weird_module = WeirdModule()

            # should be skipped by DDP: a tuple of parameter/buffer names to ignore.
            self._ddp_params_and_buffers_to_ignore = ('something', )

    class CustomCallback(Callback):

        def on_train_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
            assert isinstance(trainer.training_type_plugin.model, DistributedDataParallel)
            assert trainer.training_type_plugin.model.parameters_to_ignore == ('something', )
            assert trainer.training_type_plugin.model.module._ddp_params_and_buffers_to_ignore == ('something', )

    model = CustomModel()
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, accelerator="ddp", gpus=2, callbacks=CustomCallback())
    trainer.fit(model)
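For reference, a minimal sketch of how a user would opt into this from their own `LightningModule` after this fix; the model class and the `lookup_table` buffer here are hypothetical, chosen only for illustration:

```python
import torch
from pytorch_lightning import LightningModule


class ModelWithIgnoredBuffer(LightningModule):

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        # hypothetical buffer that DDP should not broadcast or synchronize
        self.register_buffer("lookup_table", torch.zeros(10, 32))

        # picked up by _LightningModuleWrapperBase and forwarded to DDP
        self._ddp_params_and_buffers_to_ignore = ("lookup_table",)
```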