Avoid wrapping LightningModule in DDP plugins when not fitting (#9096)
* Avoid wrapping LightningModule in DDP plugins when not fitting
parent e2ecb8f859 · commit a451997c4d
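For orientation, the diff below gates DDP wrapping on the trainer state: the LightningModule is only wrapped in a data-parallel module when the trainer is fitting, and validation/test/predict calls fall through to the unwrapped module otherwise. The following is a minimal, self-contained sketch of that dispatch rule, not the real plugin code; `ToyPlugin` and `ToyWrapper` are invented stand-ins, and the `TrainerFn` enum only mirrors a subset of `pytorch_lightning.trainer.states.TrainerFn`.

from enum import Enum


class TrainerFn(Enum):
    # mirrors a subset of pytorch_lightning.trainer.states.TrainerFn
    FITTING = "fit"
    VALIDATING = "validate"
    TESTING = "test"
    PREDICTING = "predict"


class ToyWrapper:
    """Stand-in for DistributedDataParallel: only needed while gradients are exchanged."""

    def __init__(self, module):
        self.module = module

    def __call__(self, *args, **kwargs):
        return self.module.validation_step(*args, **kwargs)


class ToyPlugin:
    """Hypothetical plugin illustrating the gating this commit adds to the DDP plugin family."""

    def __init__(self, lightning_module, trainer_fn):
        self.lightning_module = lightning_module
        self.model = lightning_module
        self.trainer_fn = trainer_fn

    def pre_dispatch(self):
        # skip wrapping the model if we are not fitting: no gradients need to be exchanged
        if self.trainer_fn == TrainerFn.FITTING:
            self.model = ToyWrapper(self.lightning_module)

    def validation_step(self, *args, **kwargs):
        if isinstance(self.model, ToyWrapper):
            # used when fitting: go through the wrapper so gradients stay in sync
            return self.model(*args, **kwargs)
        # used when validating outside of fitting: call the LightningModule directly
        return self.lightning_module.validation_step(*args, **kwargs)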
@@ -285,6 +285,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed a bug causing logging with `log_gpu_memory='min_max'` not working ([#9013](https://github.com/PyTorchLightning/pytorch-lightning/pull/9013))


+- Fixed wrapping issue: avoid wrapping LightningModule with data-parallel modules when not fitting in `DDPPlugin`, `DDPSpawnPlugin`, `DDPShardedPlugin`, `DDPSpawnShardedPlugin` ([#9096](https://github.com/PyTorchLightning/pytorch-lightning/pull/9096))
+
+
 ## [1.4.3] - 2021-08-17

 - Fixed plateau scheduler stepping on incomplete epoch ([#8861](https://github.com/PyTorchLightning/pytorch-lightning/pull/8861))
@@ -56,6 +56,7 @@ from pytorch_lightning.utilities.distributed import (
 )
 from pytorch_lightning.utilities.exceptions import DeadlockDetectedException, MisconfigurationException
 from pytorch_lightning.utilities.seed import reset_seed
+from pytorch_lightning.utilities.types import STEP_OUTPUT

 if _TORCH_GREATER_EQUAL_1_10:
     from torch.distributed.optim import DistributedOptimizer, PostLocalSGDOptimizer, ZeroRedundancyOptimizer
@@ -361,7 +362,7 @@ class DDPPlugin(ParallelPlugin):
         trainer.optimizers = optimizers
         trainer.convert_to_lightning_optimizers()

-    def configure_ddp(self):
+    def configure_ddp(self) -> None:
         self.pre_configure_ddp()
         self._model = DistributedDataParallel(
             LightningDistributedModule(self.model), device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs
@@ -380,7 +381,10 @@ class DDPPlugin(ParallelPlugin):
         if self.sync_batchnorm:
             self.model = self.configure_sync_batchnorm(self.model)

-        self.configure_ddp()
+        # skip wrapping the model if we are not fitting as no gradients need to be exchanged
+        trainer_fn = self.lightning_module.trainer.state.fn
+        if trainer_fn == TrainerFn.FITTING:
+            self.configure_ddp()

         # share ddp pids to all processes
         self._share_information_to_prevent_deadlock()
@@ -424,17 +428,22 @@ class DDPPlugin(ParallelPlugin):
             tensor = sync_ddp_if_available(tensor, group, reduce_op=reduce_op)
         return tensor

-    def training_step(self, *args, **kwargs):
+    def training_step(self, *args, **kwargs) -> Optional[Any]:
         return self.model(*args, **kwargs)

-    def validation_step(self, *args, **kwargs):
-        return self.model(*args, **kwargs)
+    def validation_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
+        if isinstance(self.model, DistributedDataParallel):
+            # used when calling `trainer.fit`
+            return self.model(*args, **kwargs)
+        else:
+            # used when calling `trainer.validate`
+            return self.lightning_module.validation_step(*args, **kwargs)

-    def test_step(self, *args, **kwargs):
-        return self.model(*args, **kwargs)
+    def test_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
+        return self.lightning_module.test_step(*args, **kwargs)

-    def predict_step(self, *args, **kwargs):
-        return self.model(*args, **kwargs)
+    def predict_step(self, *args, **kwargs) -> Any:
+        return self.lightning_module.predict_step(*args, **kwargs)

     def post_training_step(self):
         if not self.lightning_module.automatic_optimization:
@@ -46,6 +46,7 @@ from pytorch_lightning.utilities.distributed import (
     sync_ddp_if_available,
 )
 from pytorch_lightning.utilities.seed import reset_seed
+from pytorch_lightning.utilities.types import STEP_OUTPUT

 if _TORCH_GREATER_EQUAL_1_8:
     from pytorch_lightning.utilities.distributed import register_ddp_comm_hook
@@ -201,7 +202,10 @@ class DDPSpawnPlugin(ParallelPlugin):
         if self.sync_batchnorm:
             self.model = self.configure_sync_batchnorm(self.model)

-        self.configure_ddp()
+        # skip wrapping the model if we are not fitting as no gradients need to be exchanged
+        trainer_fn = self.lightning_module.trainer.state.fn
+        if trainer_fn == TrainerFn.FITTING:
+            self.configure_ddp()

         self.barrier()

@@ -254,7 +258,7 @@ class DDPSpawnPlugin(ParallelPlugin):
                 ddp_comm_wrapper=self._ddp_comm_wrapper,
             )

-    def configure_ddp(self):
+    def configure_ddp(self) -> None:
         self.pre_configure_ddp()
         self._model = DistributedDataParallel(
             LightningDistributedModule(self.model), device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs
@@ -340,17 +344,22 @@ class DDPSpawnPlugin(ParallelPlugin):
             tensor = sync_ddp_if_available(tensor, group, reduce_op=reduce_op)
         return tensor

-    def training_step(self, *args, **kwargs):
+    def training_step(self, *args, **kwargs) -> Optional[Any]:
         return self.model(*args, **kwargs)

-    def validation_step(self, *args, **kwargs):
-        return self.model(*args, **kwargs)
+    def validation_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
+        if isinstance(self.model, DistributedDataParallel):
+            # used when calling `trainer.fit`
+            return self.model(*args, **kwargs)
+        else:
+            # used when calling `trainer.validate`
+            return self.lightning_module.validation_step(*args, **kwargs)

-    def test_step(self, *args, **kwargs):
-        return self.model(*args, **kwargs)
+    def test_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
+        return self.lightning_module.test_step(*args, **kwargs)

-    def predict_step(self, *args, **kwargs):
-        return self.model(*args, **kwargs)
+    def predict_step(self, *args, **kwargs) -> Any:
+        return self.lightning_module.predict_step(*args, **kwargs)

     def post_training_step(self):
         if not self.lightning_module.automatic_optimization:
@@ -818,3 +818,12 @@ class DeepSpeedPlugin(DDPPlugin):
     @checkpoint_io.setter
     def checkpoint_io(self, plugin: CheckpointIO) -> None:
         raise MisconfigurationException("DeepSpeed currently does not support custom checkpoint plugins.")
+
+    def validation_step(self, *args, **kwargs):
+        return self.model(*args, **kwargs)
+
+    def test_step(self, *args, **kwargs):
+        return self.model(*args, **kwargs)
+
+    def predict_step(self, *args, **kwargs):
+        return self.model(*args, **kwargs)
@@ -34,7 +34,7 @@ class DDPShardedPlugin(DDPPlugin):

     _REDUCE_BUFFER_SIZE_DEFAULT = 2 ** 23  # 8M

-    def configure_ddp(self):
+    def configure_ddp(self) -> None:
         self._wrap_optimizers()
         self._model = ShardedDataParallel(
             LightningShardedDataParallel(self.model),
@@ -33,7 +33,7 @@ if _FAIRSCALE_AVAILABLE:
 class DDPSpawnShardedPlugin(DDPSpawnPlugin):
     """Optimizer sharded training provided by FairScale."""

-    def configure_ddp(self):
+    def configure_ddp(self) -> None:
         self._wrap_optimizers()
         self._model = ShardedDataParallel(
             LightningShardedDataParallel(self.model), sharded_optimizer=self.lightning_module.trainer.optimizers
@@ -18,9 +18,10 @@ import pytest
 import torch
 from torch.nn.parallel import DistributedDataParallel

-from pytorch_lightning import Trainer
+from pytorch_lightning import LightningModule, Trainer
 from pytorch_lightning.plugins import DDPPlugin
 from pytorch_lightning.plugins.environments import LightningEnvironment
+from pytorch_lightning.trainer.states import TrainerFn
 from tests.helpers.boring_model import BoringModel
 from tests.helpers.runif import RunIf

@@ -94,3 +95,37 @@ def test_incorrect_ddp_script_spawning(tmpdir):
         RuntimeError, match="Lightning attempted to launch new distributed processes with `local_rank > 0`."
     ):
         trainer.fit(model)
+
+
+@RunIf(skip_windows=True)
+def test_ddp_configure_ddp():
+    """Tests with ddp plugin."""
+    model = BoringModel()
+    ddp_plugin = DDPPlugin()
+    trainer = Trainer(
+        max_epochs=1,
+        plugins=[ddp_plugin],
+    )
+    # test that the model is wrapped when fitting
+    trainer.state.fn = TrainerFn.FITTING
+    trainer.accelerator.connect(model)
+    trainer.accelerator.setup_environment()
+    trainer.accelerator.setup(trainer)
+    trainer.lightning_module.trainer = trainer
+    assert isinstance(trainer.model, LightningModule)
+    trainer._pre_dispatch()
+    # in DDPPlugin.configure_ddp(), the model is wrapped by DistributedDataParallel
+    assert isinstance(trainer.model, DistributedDataParallel)
+
+    trainer = Trainer(
+        max_epochs=1,
+        plugins=[ddp_plugin],
+    )
+    # test that the model is not wrapped when the trainer state is not fitting
+    trainer.accelerator.connect(model)
+    trainer.accelerator.setup_environment()
+    trainer.accelerator.setup(trainer)
+    trainer.lightning_module.trainer = trainer
+    trainer._pre_dispatch()
+    # configure_ddp() is skipped, so the model is still a LightningModule
+    assert isinstance(trainer.model, LightningModule)
@@ -12,9 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import torch
+from torch.nn.parallel.distributed import DistributedDataParallel

-from pytorch_lightning import Trainer
+from pytorch_lightning import LightningModule, Trainer
 from pytorch_lightning.plugins import DDPSpawnPlugin
+from pytorch_lightning.trainer.states import TrainerFn
 from tests.helpers.boring_model import BoringDataModule, BoringModel
 from tests.helpers.runif import RunIf

@@ -77,3 +79,37 @@ def test_ddp_spawn_extra_parameters(tmpdir):
     trainer.fit(model, datamodule=dm)
     assert trainer.callback_metrics[val_name] == torch.tensor(val)
     assert model.test_val == "test_val"
+
+
+class BoringModelDDP(BoringModel):
+    def on_train_start(self) -> None:
+        """Check that the trainer model is wrapped as DistributedDataParallel during the training stage."""
+        assert isinstance(self.trainer.model, DistributedDataParallel)
+
+    def on_validation_start(self) -> None:
+        """Check that the model is wrapped only while fitting and remains a LightningModule otherwise."""
+        if self.trainer.state.fn == TrainerFn.FITTING:
+            assert isinstance(self.trainer.model, DistributedDataParallel)
+        else:
+            assert isinstance(self.trainer.model, LightningModule)
+
+    def on_test_start(self) -> None:
+        """Check that the trainer model remains a LightningModule during the test stage."""
+        assert isinstance(self.trainer.model, LightningModule)
+
+    def on_predict_start(self) -> None:
+        """Check that the trainer model remains a LightningModule during the prediction stage."""
+        assert isinstance(self.trainer.model, LightningModule)
+
+
+@RunIf(skip_windows=True)
+def test_ddp_spawn_configure_ddp(tmpdir):
+    """Tests with ddp spawn plugin."""
+    trainer = Trainer(default_root_dir=tmpdir, num_processes=2, accelerator="ddp_spawn", fast_dev_run=True)
+
+    model = BoringModelDDP()
+
+    trainer.fit(model)
+    trainer.validate(model, dataloaders=model.val_dataloader())
+    trainer.test(model, dataloaders=model.test_dataloader())
+    trainer.predict(model, dataloaders=model.predict_dataloader())
@@ -4,13 +4,18 @@ from unittest import mock
 import pytest
 import torch

-from pytorch_lightning import Trainer
+from pytorch_lightning import LightningModule, Trainer
 from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.plugins import DDPShardedPlugin, DDPSpawnShardedPlugin
+from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers.boring_model import BoringModel
 from tests.helpers.runif import RunIf

+if _FAIRSCALE_AVAILABLE:
+    from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel
+
+
 @pytest.mark.parametrize("clip_val", [0, 10])
 @RunIf(min_gpus=1, skip_windows=True, amp_native=True, fairscale=True)
@@ -249,3 +254,37 @@ def test_ddp_sharded_plugin_manual_optimization(tmpdir):
     model = ManualBoringModel()
     trainer = Trainer(default_root_dir=tmpdir, accelerator="ddp_sharded", fast_dev_run=2, gpus=2)
     trainer.fit(model)
+
+
+class BoringModelSharded(BoringModel):
+    def on_train_start(self) -> None:
+        """Check that the trainer model is wrapped as ShardedDataParallel during the training stage."""
+        assert isinstance(self.trainer.model, ShardedDataParallel)
+
+    def on_test_start(self) -> None:
+        """Check that the trainer model remains a LightningModule during the test stage."""
+        assert isinstance(self.trainer.model, LightningModule)
+
+    def on_validation_start(self) -> None:
+        """Check that the model is wrapped only while fitting and remains a LightningModule otherwise."""
+        if self.trainer.state.fn == TrainerFn.FITTING:
+            assert isinstance(self.trainer.model, ShardedDataParallel)
+        else:
+            assert isinstance(self.trainer.model, LightningModule)
+
+    def on_predict_start(self) -> None:
+        """Check that the trainer model remains a LightningModule during the prediction stage."""
+        assert isinstance(self.trainer.model, LightningModule)
+
+
+@RunIf(skip_windows=True, fairscale=True)
+def test_configure_ddp(tmpdir):
+    """Tests with ddp sharded plugin."""
+    trainer = Trainer(default_root_dir=tmpdir, accelerator="ddp_sharded", fast_dev_run=True)
+
+    model = BoringModelSharded()
+
+    trainer.fit(model)
+    trainer.test(model, dataloaders=model.test_dataloader())
+    trainer.validate(model, dataloaders=model.val_dataloader())
+    trainer.predict(model, dataloaders=model.predict_dataloader())
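The observable effect, as exercised by the new tests above: with any of the four DDP-family plugins, only `trainer.fit` wraps the module, while `trainer.validate`, `trainer.test`, and `trainer.predict` now run against the plain LightningModule. A rough usage sketch follows; it assumes the repository's test helpers and a suitable multi-process environment are available, and is not part of the commit itself.

from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DDPPlugin
from tests.helpers.boring_model import BoringModel

model = BoringModel()
trainer = Trainer(fast_dev_run=True, plugins=[DDPPlugin()])

trainer.fit(model)       # model is wrapped in DistributedDataParallel for gradient sync
trainer.validate(model)  # model stays an unwrapped LightningModule; validation_step is called directly
trainer.test(model)      # same: test_step is dispatched to the LightningModule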