diff --git a/pytorch_lightning/plugins/sharded_plugin.py b/pytorch_lightning/plugins/sharded_plugin.py
index 2654b329d9..84b23c3230 100644
--- a/pytorch_lightning/plugins/sharded_plugin.py
+++ b/pytorch_lightning/plugins/sharded_plugin.py
@@ -33,11 +33,6 @@ class DDPShardedPlugin(DDPPlugin):
             self, model: LightningModule, device_ids: List[int]
     ):
         self._wrap_optimizers(model)
-        if model.trainer.testing:  # Revert to standard DDP if testing
-            return super().configure_ddp(
-                model=model,
-                device_ids=device_ids
-            )
         return LightningShardedDataParallel(model, sharded_optimizer=model.trainer.optimizers)
 
     def optimizer_state(self, optimizer: 'OSS') -> Optional[dict]:
diff --git a/tests/plugins/test_sharded_plugin.py b/tests/plugins/test_sharded_plugin.py
index 11ec6cbe73..87d568a3cd 100644
--- a/tests/plugins/test_sharded_plugin.py
+++ b/tests/plugins/test_sharded_plugin.py
@@ -278,3 +278,41 @@ def test_ddp_sharded_plugin_resume_from_checkpoint_gpu_to_cpu(tmpdir):
     trainer.fit(model)
 
     return 1
+
+
+@pytest.mark.skipif(platform.system() == "Windows",
+                    reason="Distributed training is not supported on Windows")
+@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
+def test_ddp_sharded_plugin_test(tmpdir):
+    """
+        Test to ensure we can use test without fit
+    """
+    model = BoringModel()
+    trainer = Trainer(
+        accelerator='ddp_cpu',
+        plugins=[DDPShardedPlugin()],
+        fast_dev_run=True,
+    )
+
+    trainer.test(model)
+    return 1
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
+@pytest.mark.skipif(platform.system() == "Windows",
+                    reason="Distributed training is not supported on Windows")
+@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
+def test_ddp_sharded_plugin_test_multigpu(tmpdir):
+    """
+        Test to ensure we can use test without fit
+    """
+    model = BoringModel()
+    trainer = Trainer(
+        accelerator='ddp_spawn',
+        gpus=2,
+        plugins=[DDPShardedPlugin()],
+        fast_dev_run=True,
+    )
+
+    trainer.test(model)
+    return 1
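
A possible consolidation (a sketch, not part of this diff): the two new tests differ only in their Trainer arguments, so they could be collapsed into one parametrized test. The import paths and the test name test_ddp_sharded_plugin_test_only below are assumptions for illustration; in the real module, BoringModel, DDPShardedPlugin, and FAIRSCALE_AVAILABLE are already imported at the top of tests/plugins/test_sharded_plugin.py.

# Sketch: one parametrized test covering the ddp_cpu and ddp_spawn multi-GPU cases.
import platform

import pytest
import torch

from pytorch_lightning import Trainer
from pytorch_lightning.plugins.sharded_plugin import DDPShardedPlugin
from pytorch_lightning.utilities import FAIRSCALE_AVAILABLE  # assumed import path
from tests.base.boring_model import BoringModel  # assumed import path


@pytest.mark.skipif(platform.system() == "Windows",
                    reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
@pytest.mark.parametrize("trainer_kwargs", [
    dict(accelerator='ddp_cpu'),
    pytest.param(
        dict(accelerator='ddp_spawn', gpus=2),
        # Only the multi-GPU case needs the GPU-count skip condition.
        marks=pytest.mark.skipif(torch.cuda.device_count() < 2,
                                 reason="test requires multi-GPU machine"),
    ),
])
def test_ddp_sharded_plugin_test_only(tmpdir, trainer_kwargs):  # hypothetical name
    """
        Test to ensure we can use test without fit
    """
    model = BoringModel()
    trainer = Trainer(
        plugins=[DDPShardedPlugin()],
        fast_dev_run=True,
        **trainer_kwargs,
    )

    trainer.test(model)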