running ddp tests
This commit is contained in:
parent
1313a7f397
commit
3451a62650
|
@ -22,35 +22,6 @@ np.random.seed(SEED)
|
|||
# TESTS
|
||||
# ------------------------------------------------------------------------
|
||||
|
||||
def test_amp_gpu_ddp():
|
||||
"""
|
||||
Make sure DDP + AMP work
|
||||
:return:
|
||||
"""
|
||||
if not torch.cuda.is_available():
|
||||
warnings.warn('test_amp_gpu_ddp cannot run. Rerun on a GPU node to run this test')
|
||||
return
|
||||
if not torch.cuda.device_count() > 1:
|
||||
warnings.warn('test_amp_gpu_ddp cannot run. Rerun on a node with 2+ GPUs to run this test')
|
||||
return
|
||||
|
||||
os.environ['MASTER_PORT'] = str(np.random.randint(12000, 19000, 1)[0])
|
||||
|
||||
hparams = get_hparams()
|
||||
model = LightningTestModel(hparams)
|
||||
|
||||
trainer_options = dict(
|
||||
progress_bar=True,
|
||||
max_nb_epochs=1,
|
||||
gpus=[0, 1],
|
||||
distributed_backend='ddp',
|
||||
use_amp=True
|
||||
)
|
||||
|
||||
run_gpu_model_test(trainer_options, model, hparams)
|
||||
|
||||
|
||||
|
||||
def test_cpu_model():
|
||||
"""
|
||||
Make sure model trains on CPU
|
||||
|
|
Loading…
Reference in New Issue