[bug/feat] Support parameters_to_ignore in DDP (#7239)
* update
* update
* update
* update on comments
* update
parent 7fe8d18477
commit 5a113a2f05
@@ -367,6 +367,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed metrics not being properly logged with `precision=16` and `manual_optimization` ([#7228](https://github.com/PyTorchLightning/pytorch-lightning/pull/7228))

- Fixed `parameters_to_ignore` not properly set to DDPWrapper ([#7239](https://github.com/PyTorchLightning/pytorch-lightning/pull/7239))

## [1.2.7] - 2021-04-06

### Fixed

@@ -36,6 +36,9 @@ class _LightningModuleWrapperBase(DeviceDtypeModuleMixin, torch.nn.Module):

        super().__init__()
        self.module = pl_module

        # set the parameters_to_ignore from the LightningModule
        self._ddp_params_and_buffers_to_ignore = getattr(pl_module, "_ddp_params_and_buffers_to_ignore", [])

    def forward(self, *inputs, **kwargs):
        trainer = self.module.trainer

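For context, a minimal sketch of how the attribute forwarded above is meant to be used: a user lists parameter/buffer names in `_ddp_params_and_buffers_to_ignore` on their `LightningModule`, and the wrapper copies that attribute so `DistributedDataParallel` (torch >= 1.8) can exclude those names from synchronization. The module and names below are illustrative, not part of this commit.

import torch
from pytorch_lightning import LightningModule


class SketchModel(LightningModule):
    """Illustrative LightningModule; only __init__ is shown."""

    def __init__(self):
        super().__init__()
        self.backbone = torch.nn.Linear(32, 32)
        self.frozen_head = torch.nn.Linear(32, 2)
        # Names listed here are forwarded to DDP via the wrapper above,
        # so DDP can leave these parameters out of gradient synchronization.
        self._ddp_params_and_buffers_to_ignore = ["frozen_head.weight", "frozen_head.bias"]
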
@@ -18,8 +18,11 @@ from unittest.mock import patch

import pytest
import torch
from torch.nn.parallel.distributed import DistributedDataParallel

import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Callback
from tests.accelerators import ddp_model
from tests.helpers.boring_model import BoringModel
from tests.helpers.runif import RunIf

@@ -117,3 +120,35 @@ def test_ddp_torch_dist_is_available_in_setup(mock_set_device, mock_is_available
    )
    with pytest.raises(SystemExit):
        trainer.fit(model)


@RunIf(min_gpus=2, min_torch="1.8.1", special=True)
def test_ddp_wrapper(tmpdir):
    """Test that parameters to ignore are carried over to DDP."""

    class WeirdModule(torch.nn.Module):

        def _save_to_state_dict(self, destination, prefix, keep_vars):
            return {"something": "something"}

    class CustomModel(BoringModel):

        def __init__(self):
            super().__init__()
            self.weird_module = WeirdModule()

            # these names should be skipped by DDP
            self._ddp_params_and_buffers_to_ignore = ("something", )

    class CustomCallback(Callback):

        def on_train_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
            assert isinstance(trainer.training_type_plugin.model, DistributedDataParallel)
            assert trainer.training_type_plugin.model.parameters_to_ignore == ("something", )
            assert trainer.training_type_plugin.model.module._ddp_params_and_buffers_to_ignore == ("something", )

    model = CustomModel()
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, accelerator="ddp", gpus=2, callbacks=CustomCallback())
    trainer.fit(model)

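The assertions in `CustomCallback` rely on how upstream DDP discovers the attribute. The helper below is a paraphrase (an assumption based on torch >= 1.8 behaviour, not a verbatim copy of the PyTorch source) of the lookup `DistributedDataParallel` performs on the module it wraps; DDP stores the result as `parameters_to_ignore`.

from torch import nn


def _params_to_ignore_seen_by_ddp(module: nn.Module):
    # DDP receives the Lightning wrapper as its `module`, so copying the
    # attribute onto the wrapper (see the wrapper change above) is what makes
    # the names visible here and, in turn, on `model.parameters_to_ignore`.
    return getattr(module, "_ddp_params_and_buffers_to_ignore", [])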