Add ddp launcher for ddp testing
This commit is contained in:
parent
3a3eaa5e0c
commit
b44dd7507d
|
@ -17,6 +17,7 @@ import pytest
|
|||
import torch
|
||||
|
||||
from pytorch_lightning import Trainer
|
||||
from tests.backends.launcher import DDPLauncher
|
||||
from tests.base.boring_model import BoringModel
|
||||
|
||||
|
||||
|
@ -86,7 +87,10 @@ def test_get_model_gpu(tmpdir):
|
|||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
|
||||
def test_get_model_ddp_gpu(tmpdir):
|
||||
@DDPLauncher.run("--accelerator [accelerator]",
|
||||
max_epochs=["1"],
|
||||
accelerator=["ddp", "ddp_spawn"])
|
||||
def test_get_model_ddp_gpu(tmpdir, args=None):
|
||||
"""
|
||||
Tests that :meth:`trainer.get_model` extracts the model correctly when using GPU + ddp accelerators
|
||||
"""
|
||||
|
@ -100,6 +104,7 @@ def test_get_model_ddp_gpu(tmpdir):
|
|||
limit_val_batches=2,
|
||||
max_epochs=1,
|
||||
gpus=1,
|
||||
accelerator='ddp_spawn'
|
||||
accelerator=args.accelerator
|
||||
)
|
||||
trainer.fit(model)
|
||||
return 1
|
||||
|
|
Loading…
Reference in New Issue