Avoid wrapping LightningModule in DDP plugins when not fitting (#9096)

Authored by four4fish on 2021-09-01 19:23:59 -07:00, committed by GitHub
commit a451997c4d (parent e2ecb8f859)
9 changed files with 163 additions and 23 deletions
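The gist of the change, shown here as a minimal standalone sketch rather than the plugin code itself: only fitting exchanges gradients, so the DDP-family plugins now wrap the LightningModule in a data-parallel module only when the trainer is in the FITTING state. The configure_model helper and the single-process gloo group below are assumptions made purely so the example runs.

import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel

# A single-process gloo group is enough to build DistributedDataParallel on CPU.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)


def configure_model(module: torch.nn.Module, trainer_fn: str) -> torch.nn.Module:
    # Hypothetical helper mirroring the gating added in this PR: only fitting
    # needs gradient synchronization, so only fitting pays for the DDP wrapper.
    if trainer_fn == "fit":
        return DistributedDataParallel(module)
    return module


model = torch.nn.Linear(4, 2)
assert isinstance(configure_model(model, "fit"), DistributedDataParallel)
assert not isinstance(configure_model(model, "validate"), DistributedDataParallel)

dist.destroy_process_group()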

CHANGELOG.md

@@ -285,6 +285,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed a bug causing logging with `log_gpu_memory='min_max'` not working ([#9013](https://github.com/PyTorchLightning/pytorch-lightning/pull/9013))
- Fixed wrapping issue: avoid wrapping LightningModule with data-parallel modules when not fitting in `DDPPlugin`, `DDPSpawnPlugin`, `DDPShardedPlugin`, `DDPSpawnShardedPlugin` ([#9096](https://github.com/PyTorchLightning/pytorch-lightning/pull/9096))
## [1.4.3] - 2021-08-17
- Fixed plateau scheduler stepping on incomplete epoch ([#8861](https://github.com/PyTorchLightning/pytorch-lightning/pull/8861))

pytorch_lightning/plugins/training_type/ddp.py

@@ -56,6 +56,7 @@ from pytorch_lightning.utilities.distributed import (
)
from pytorch_lightning.utilities.exceptions import DeadlockDetectedException, MisconfigurationException
from pytorch_lightning.utilities.seed import reset_seed
from pytorch_lightning.utilities.types import STEP_OUTPUT
if _TORCH_GREATER_EQUAL_1_10:
from torch.distributed.optim import DistributedOptimizer, PostLocalSGDOptimizer, ZeroRedundancyOptimizer
@@ -361,7 +362,7 @@ class DDPPlugin(ParallelPlugin):
trainer.optimizers = optimizers
trainer.convert_to_lightning_optimizers()
def configure_ddp(self):
def configure_ddp(self) -> None:
self.pre_configure_ddp()
self._model = DistributedDataParallel(
LightningDistributedModule(self.model), device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs
@@ -380,7 +381,10 @@ class DDPPlugin(ParallelPlugin):
if self.sync_batchnorm:
self.model = self.configure_sync_batchnorm(self.model)
self.configure_ddp()
# skip wrapping the model if we are not fitting as no gradients need to be exchanged
trainer_fn = self.lightning_module.trainer.state.fn
if trainer_fn == TrainerFn.FITTING:
self.configure_ddp()
# share ddp pids to all processes
self._share_information_to_prevent_deadlock()
@@ -424,17 +428,22 @@ class DDPPlugin(ParallelPlugin):
tensor = sync_ddp_if_available(tensor, group, reduce_op=reduce_op)
return tensor
def training_step(self, *args, **kwargs):
def training_step(self, *args, **kwargs) -> Optional[Any]:
return self.model(*args, **kwargs)
def validation_step(self, *args, **kwargs):
return self.model(*args, **kwargs)
def validation_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
if isinstance(self.model, DistributedDataParallel):
# used when calling `trainer.fit`
return self.model(*args, **kwargs)
else:
# used when calling `trainer.validate`
return self.lightning_module.validation_step(*args, **kwargs)
def test_step(self, *args, **kwargs):
return self.model(*args, **kwargs)
def test_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
return self.lightning_module.test_step(*args, **kwargs)
def predict_step(self, *args, **kwargs):
return self.model(*args, **kwargs)
def predict_step(self, *args, **kwargs) -> Any:
return self.lightning_module.predict_step(*args, **kwargs)
def post_training_step(self):
if not self.lightning_module.automatic_optimization:
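With the wrapper gone outside of fitting, the step methods above dispatch by type: validation runs both inside trainer.fit (module wrapped) and in trainer.validate (bare LightningModule), so it checks for DistributedDataParallel, while test and predict never run while fitting and therefore always call the LightningModule hooks directly. A rough sketch of that dispatch, with illustrative names rather than the actual plugin method:

from typing import Any

import torch
from torch.nn.parallel import DistributedDataParallel


def run_validation_step(model: torch.nn.Module, lightning_module: Any, *args: Any, **kwargs: Any) -> Any:
    # Illustrative helper, not DDPPlugin.validation_step itself.
    if isinstance(model, DistributedDataParallel):
        # trainer.fit: the module is wrapped, so go through the DDP forward,
        # which routes to the underlying LightningModule's step method.
        return model(*args, **kwargs)
    # trainer.validate: the module was never wrapped; call the hook directly.
    return lightning_module.validation_step(*args, **kwargs)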

pytorch_lightning/plugins/training_type/ddp_spawn.py

@@ -46,6 +46,7 @@ from pytorch_lightning.utilities.distributed import (
sync_ddp_if_available,
)
from pytorch_lightning.utilities.seed import reset_seed
from pytorch_lightning.utilities.types import STEP_OUTPUT
if _TORCH_GREATER_EQUAL_1_8:
from pytorch_lightning.utilities.distributed import register_ddp_comm_hook
@@ -201,7 +202,10 @@ class DDPSpawnPlugin(ParallelPlugin):
if self.sync_batchnorm:
self.model = self.configure_sync_batchnorm(self.model)
self.configure_ddp()
# skip wrapping the model if we are not fitting as no gradients need to be exchanged
trainer_fn = self.lightning_module.trainer.state.fn
if trainer_fn == TrainerFn.FITTING:
self.configure_ddp()
self.barrier()
@@ -254,7 +258,7 @@ class DDPSpawnPlugin(ParallelPlugin):
ddp_comm_wrapper=self._ddp_comm_wrapper,
)
def configure_ddp(self):
def configure_ddp(self) -> None:
self.pre_configure_ddp()
self._model = DistributedDataParallel(
LightningDistributedModule(self.model), device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs
@@ -340,17 +344,22 @@ class DDPSpawnPlugin(ParallelPlugin):
tensor = sync_ddp_if_available(tensor, group, reduce_op=reduce_op)
return tensor
def training_step(self, *args, **kwargs):
def training_step(self, *args, **kwargs) -> Optional[Any]:
return self.model(*args, **kwargs)
def validation_step(self, *args, **kwargs):
return self.model(*args, **kwargs)
def validation_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
if isinstance(self.model, DistributedDataParallel):
# used when calling `trainer.fit`
return self.model(*args, **kwargs)
else:
# used when calling `trainer.validate`
return self.lightning_module.validation_step(*args, **kwargs)
def test_step(self, *args, **kwargs):
return self.model(*args, **kwargs)
def test_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
return self.lightning_module.test_step(*args, **kwargs)
def predict_step(self, *args, **kwargs):
return self.model(*args, **kwargs)
def predict_step(self, *args, **kwargs) -> Any:
return self.lightning_module.predict_step(*args, **kwargs)
def post_training_step(self):
if not self.lightning_module.automatic_optimization:

pytorch_lightning/plugins/training_type/deepspeed.py

@@ -818,3 +818,12 @@ class DeepSpeedPlugin(DDPPlugin):
@checkpoint_io.setter
def checkpoint_io(self, plugin: CheckpointIO) -> None:
raise MisconfigurationException("DeepSpeed currently does not support custom checkpoint plugins.")
def validation_step(self, *args, **kwargs):
return self.model(*args, **kwargs)
def test_step(self, *args, **kwargs):
return self.model(*args, **kwargs)
def predict_step(self, *args, **kwargs):
return self.model(*args, **kwargs)

pytorch_lightning/plugins/training_type/sharded.py

@@ -34,7 +34,7 @@ class DDPShardedPlugin(DDPPlugin):
_REDUCE_BUFFER_SIZE_DEFAULT = 2 ** 23 # 8M
def configure_ddp(self):
def configure_ddp(self) -> None:
self._wrap_optimizers()
self._model = ShardedDataParallel(
LightningShardedDataParallel(self.model),

pytorch_lightning/plugins/training_type/sharded_spawn.py

@@ -33,7 +33,7 @@ if _FAIRSCALE_AVAILABLE:
class DDPSpawnShardedPlugin(DDPSpawnPlugin):
"""Optimizer sharded training provided by FairScale."""
def configure_ddp(self):
def configure_ddp(self) -> None:
self._wrap_optimizers()
self._model = ShardedDataParallel(
LightningShardedDataParallel(self.model), sharded_optimizer=self.lightning_module.trainer.optimizers

tests/plugins/test_ddp_plugin.py

@@ -18,9 +18,10 @@ import pytest
import torch
from torch.nn.parallel import DistributedDataParallel
from pytorch_lightning import Trainer
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.plugins import DDPPlugin
from pytorch_lightning.plugins.environments import LightningEnvironment
from pytorch_lightning.trainer.states import TrainerFn
from tests.helpers.boring_model import BoringModel
from tests.helpers.runif import RunIf
@@ -94,3 +95,37 @@ def test_incorrect_ddp_script_spawning(tmpdir):
RuntimeError, match="Lightning attempted to launch new distributed processes with `local_rank > 0`."
):
trainer.fit(model)
@RunIf(skip_windows=True)
def test_ddp_configure_ddp():
"""Tests with ddp plugin."""
model = BoringModel()
ddp_plugin = DDPPlugin()
trainer = Trainer(
max_epochs=1,
plugins=[ddp_plugin],
)
# when fitting, the model should get wrapped in DistributedDataParallel
trainer.state.fn = TrainerFn.FITTING
trainer.accelerator.connect(model)
trainer.accelerator.setup_environment()
trainer.accelerator.setup(trainer)
trainer.lightning_module.trainer = trainer
assert isinstance(trainer.model, LightningModule)
trainer._pre_dispatch()
# in DDPPlugin.configure_ddp(), the model is wrapped in DistributedDataParallel
assert isinstance(trainer.model, DistributedDataParallel)
trainer = Trainer(
max_epochs=1,
plugins=[ddp_plugin],
)
# when the trainer state is not FITTING, the model should not be wrapped
trainer.accelerator.connect(model)
trainer.accelerator.setup_environment()
trainer.accelerator.setup(trainer)
trainer.lightning_module.trainer = trainer
trainer._pre_dispatch()
# configure_ddp() is skipped, so the model is still a plain LightningModule
assert isinstance(trainer.model, LightningModule)

tests/plugins/test_ddp_spawn_plugin.py

@@ -12,9 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch.nn.parallel.distributed import DistributedDataParallel
from pytorch_lightning import Trainer
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.plugins import DDPSpawnPlugin
from pytorch_lightning.trainer.states import TrainerFn
from tests.helpers.boring_model import BoringDataModule, BoringModel
from tests.helpers.runif import RunIf
@@ -77,3 +79,37 @@ def test_ddp_spawn_extra_parameters(tmpdir):
trainer.fit(model, datamodule=dm)
assert trainer.callback_metrics[val_name] == torch.tensor(val)
assert model.test_val == "test_val"
class BoringModelDDP(BoringModel):
def on_train_start(self) -> None:
"""Check if trainer module is wrapped as DistributedDataParallel during training stage."""
assert isinstance(self.trainer.model, DistributedDataParallel)
def on_validation_start(self) -> None:
"""Check if trainer module remains as LightningModule during test stage."""
if self.trainer.state.fn == TrainerFn.FITTING:
assert isinstance(self.trainer.model, DistributedDataParallel)
else:
assert isinstance(self.trainer.model, LightningModule)
def on_test_start(self) -> None:
"""Check if trainer module remains as LightningModule during test stage."""
assert isinstance(self.trainer.model, LightningModule)
def on_predict_start(self) -> None:
"""Check if trainer module remains as LightningModule during prediction stage."""
assert isinstance(self.trainer.model, LightningModule)
@RunIf(skip_windows=True)
def test_ddp_spawn_configure_ddp(tmpdir):
"""Tests with ddp spawn plugin."""
trainer = Trainer(default_root_dir=tmpdir, num_processes=2, accelerator="ddp_spawn", fast_dev_run=True)
model = BoringModelDDP()
trainer.fit(model)
trainer.validate(model, dataloaders=model.val_dataloader())
trainer.test(model, dataloaders=model.test_dataloader())
trainer.predict(model, dataloaders=model.predict_dataloader())

tests/plugins/test_sharded_plugin.py

@@ -4,13 +4,18 @@ from unittest import mock
import pytest
import torch
from pytorch_lightning import Trainer
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.plugins import DDPShardedPlugin, DDPSpawnShardedPlugin
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers.boring_model import BoringModel
from tests.helpers.runif import RunIf
if _FAIRSCALE_AVAILABLE:
from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel
@pytest.mark.parametrize("clip_val", [0, 10])
@RunIf(min_gpus=1, skip_windows=True, amp_native=True, fairscale=True)
@@ -249,3 +254,37 @@ def test_ddp_sharded_plugin_manual_optimization(tmpdir):
model = ManualBoringModel()
trainer = Trainer(default_root_dir=tmpdir, accelerator="ddp_sharded", fast_dev_run=2, gpus=2)
trainer.fit(model)
class BoringModelSharded(BoringModel):
def on_train_start(self) -> None:
"""Check if trainer module is wrapped as ShardedDataParallel during training stage."""
assert isinstance(self.trainer.model, ShardedDataParallel)
def on_test_start(self) -> None:
"""Check if trainer module remains as LightningModule during test stage."""
assert isinstance(self.trainer.model, LightningModule)
def on_validation_start(self) -> None:
"""Check if trainer module remains as LightningModule during test stage."""
if self.trainer.state.fn == TrainerFn.FITTING:
assert isinstance(self.trainer.model, ShardedDataParallel)
else:
assert isinstance(self.trainer.model, LightningModule)
def on_predict_start(self) -> None:
"""Check if trainer module remains as LightningModule during prediction stage."""
assert isinstance(self.trainer.model, LightningModule)
@RunIf(skip_windows=True, fairscale=True)
def test_configure_ddp(tmpdir):
"""Tests with ddp sharded plugin."""
trainer = Trainer(default_root_dir=tmpdir, accelerator="ddp_sharded", fast_dev_run=True)
model = BoringModelSharded()
trainer.fit(model)
trainer.test(model, dataloaders=model.test_dataloader())
trainer.validate(model, dataloaders=model.val_dataloader())
trainer.predict(model, dataloaders=model.predict_dataloader())