From 3600535bc5be0ec6f5428292811e4b5762b40d07 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Wed, 24 Jul 2019 18:16:22 -0400
Subject: [PATCH] fixed correct module on hpc save

---
 tests/test_models.py | 30 ++++++++++++++++++++++++++++++
 1 file changed, 30 insertions(+)

diff --git a/tests/test_models.py b/tests/test_models.py
index ccd63398e6..59d4d6f525 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -21,6 +21,36 @@ np.random.seed(SEED)
 # ------------------------------------------------------------------------
 # TESTS
 # ------------------------------------------------------------------------
+
+def test_amp_gpu_ddp():
+    """
+    Make sure DDP + AMP work
+    :return:
+    """
+    if not torch.cuda.is_available():
+        warnings.warn('test_amp_gpu_ddp cannot run. Rerun on a GPU node to run this test')
+        return
+    if not torch.cuda.device_count() > 1:
+        warnings.warn('test_amp_gpu_ddp cannot run. Rerun on a node with 2+ GPUs to run this test')
+        return
+
+    os.environ['MASTER_PORT'] = str(np.random.randint(12000, 19000, 1)[0])
+
+    hparams = get_hparams()
+    model = LightningTestModel(hparams)
+
+    trainer_options = dict(
+        progress_bar=True,
+        max_nb_epochs=1,
+        gpus=[0, 1],
+        distributed_backend='ddp',
+        use_amp=True
+    )
+
+    run_gpu_model_test(trainer_options, model, hparams)
+
+
 def test_cpu_model():
     """
     Make sure model trains on CPU