fixed correct module on hpc save

This commit is contained in:
William Falcon 2019-07-24 18:16:22 -04:00
parent d7be0aae1c
commit 3600535bc5
1 changed files with 30 additions and 0 deletions

View File

@ -21,6 +21,36 @@ np.random.seed(SEED)
# ------------------------------------------------------------------------
# TESTS
# ------------------------------------------------------------------------
def test_amp_gpu_ddp():
"""
Make sure DDP + AMP work
:return:
"""
if not torch.cuda.is_available():
warnings.warn('test_amp_gpu_ddp cannot run. Rerun on a GPU node to run this test')
return
if not torch.cuda.device_count() > 1:
warnings.warn('test_amp_gpu_ddp cannot run. Rerun on a node with 2+ GPUs to run this test')
return
os.environ['MASTER_PORT'] = str(np.random.randint(12000, 19000, 1)[0])
hparams = get_hparams()
model = LightningTestModel(hparams)
trainer_options = dict(
progress_bar=True,
max_nb_epochs=1,
gpus=[0, 1],
distributed_backend='ddp',
use_amp=True
)
run_gpu_model_test(trainer_options, model, hparams)
def test_cpu_model():
"""
Make sure model trains on CPU