Removed old eval logic, added eval tests
This commit is contained in:
parent
2e50c2e653
commit
ab655e5118
|
@ -33,11 +33,6 @@ class DDPShardedPlugin(DDPPlugin):
|
|||
self, model: LightningModule, device_ids: List[int]
|
||||
):
|
||||
self._wrap_optimizers(model)
|
||||
if model.trainer.testing: # Revert to standard DDP if testing
|
||||
return super().configure_ddp(
|
||||
model=model,
|
||||
device_ids=device_ids
|
||||
)
|
||||
return LightningShardedDataParallel(model, sharded_optimizer=model.trainer.optimizers)
|
||||
|
||||
def optimizer_state(self, optimizer: 'OSS') -> Optional[dict]:
|
||||
|
|
|
@ -278,3 +278,41 @@ def test_ddp_sharded_plugin_resume_from_checkpoint_gpu_to_cpu(tmpdir):
|
|||
|
||||
trainer.fit(model)
|
||||
return 1
|
||||
|
||||
|
||||
@pytest.mark.skipif(platform.system() == "Windows",
|
||||
reason="Distributed training is not supported on Windows")
|
||||
@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
|
||||
def test_ddp_sharded_plugin_test(tmpdir):
|
||||
"""
|
||||
Test to ensure we can use test without fit
|
||||
"""
|
||||
model = BoringModel()
|
||||
trainer = Trainer(
|
||||
accelerator='ddp_cpu',
|
||||
plugins=[DDPShardedPlugin()],
|
||||
fast_dev_run=True,
|
||||
)
|
||||
|
||||
trainer.test(model)
|
||||
return 1
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
|
||||
@pytest.mark.skipif(platform.system() == "Windows",
|
||||
reason="Distributed training is not supported on Windows")
|
||||
@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
|
||||
def test_ddp_sharded_plugin_test_multigpu(tmpdir):
|
||||
"""
|
||||
Test to ensure we can use test without fit
|
||||
"""
|
||||
model = BoringModel()
|
||||
trainer = Trainer(
|
||||
accelerator='ddp_spawn',
|
||||
gpus=2,
|
||||
plugins=[DDPShardedPlugin()],
|
||||
fast_dev_run=True,
|
||||
)
|
||||
|
||||
trainer.test(model)
|
||||
return 1
|
||||
|
|
Loading…
Reference in New Issue