From b9e0d841dcf4fd5c01f205758b894f4f9c8f77f1 Mon Sep 17 00:00:00 2001
From: williamFalcon
Date: Sun, 28 Jul 2019 06:21:41 -0700
Subject: [PATCH] fixed lr scheduler tests

---
 tests/test_models.py | 58 ++++++++++++++++++++++----------------------
 1 file changed, 29 insertions(+), 29 deletions(-)

diff --git a/tests/test_models.py b/tests/test_models.py
index e8fa339b54..9d40c0da93 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -24,6 +24,33 @@ np.random.seed(SEED)
 # ------------------------------------------------------------------------
 # TESTS
 # ------------------------------------------------------------------------
+def test_amp_gpu_ddp():
+    """
+    Make sure DDP + AMP work
+    :return:
+    """
+    if not torch.cuda.is_available():
+        warnings.warn('test_amp_gpu_ddp cannot run. Rerun on a GPU node to run this test')
+        return
+    if not torch.cuda.device_count() > 1:
+        warnings.warn('test_amp_gpu_ddp cannot run. Rerun on a node with 2+ GPUs to run this test')
+        return
+
+    os.environ['MASTER_PORT'] = str(np.random.randint(12000, 19000, 1)[0])
+
+    hparams = get_hparams()
+    model = LightningTestModel(hparams)
+
+    trainer_options = dict(
+        progress_bar=True,
+        max_nb_epochs=1,
+        gpus=[0, 1],
+        distributed_backend='ddp',
+        use_amp=True
+    )
+
+    run_gpu_model_test(trainer_options, model, hparams)
+
 
 def test_cpu_slurm_save_load():
     """
@@ -280,7 +307,7 @@ def test_amp_gpu_ddp_slurm_managed():
     if trainer.use_ddp:
         # on hpc this would work fine... but need to hack it for the purpose of the test
         trainer.model = pretrained_model
-        trainer.optimizers = pretrained_model.configure_optimizers()
+        trainer.optimizers, trainer.lr_schedulers = pretrained_model.configure_optimizers()
 
     # test HPC loading / saving
     trainer.hpc_save(save_dir, exp)
@@ -477,33 +504,6 @@ def test_multi_gpu_model_ddp():
     run_gpu_model_test(trainer_options, model, hparams)
 
 
-def test_amp_gpu_ddp():
-    """
-    Make sure DDP + AMP work
-    :return:
-    """
-    if not torch.cuda.is_available():
-        warnings.warn('test_amp_gpu_ddp cannot run. Rerun on a GPU node to run this test')
-        return
-    if not torch.cuda.device_count() > 1:
-        warnings.warn('test_amp_gpu_ddp cannot run. Rerun on a node with 2+ GPUs to run this test')
-        return
-
-    os.environ['MASTER_PORT'] = str(np.random.randint(12000, 19000, 1)[0])
-
-    hparams = get_hparams()
-    model = LightningTestModel(hparams)
-
-    trainer_options = dict(
-        progress_bar=True,
-        max_nb_epochs=1,
-        gpus=[0, 1],
-        distributed_backend='ddp',
-        use_amp=True
-    )
-
-    run_gpu_model_test(trainer_options, model, hparams)
-
 
 def test_ddp_sampler_error():
     """
@@ -574,7 +574,7 @@ def run_gpu_model_test(trainer_options, model, hparams, on_gpu=True):
     if trainer.use_ddp:
         # on hpc this would work fine... but need to hack it for the purpose of the test
         trainer.model = pretrained_model
-        trainer.optimizers = pretrained_model.configure_optimizers()
+        trainer.optimizers, trainer.lr_schedulers = pretrained_model.configure_optimizers()
 
     # test HPC loading / saving
     trainer.hpc_save(save_dir, exp)
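
Note on the two one-line hunks above: the tests now assume configure_optimizers() returns a pair of lists, (optimizers, lr_schedulers), which they unpack into trainer.optimizers and trainer.lr_schedulers. The sketch below only illustrates that return contract on a plain nn.Module stand-in; the class name and layer sizes are hypothetical and not taken from LightningTestModel.

import torch
from torch import nn


class TinyTestModelSketch(nn.Module):
    """Hypothetical stand-in for LightningTestModel; illustrative only."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(28 * 28, 10)

    def forward(self, x):
        return self.layer(x)

    def configure_optimizers(self):
        # Return two lists so callers can unpack both values, mirroring the
        # updated test code:
        #   trainer.optimizers, trainer.lr_schedulers = model.configure_optimizers()
        optimizer = torch.optim.Adam(self.parameters(), lr=0.02)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)
        return [optimizer], [scheduler]


if __name__ == '__main__':
    model = TinyTestModelSketch()
    optimizers, lr_schedulers = model.configure_optimizers()
    assert len(optimizers) == 1 and len(lr_schedulers) == 1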
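
The relocated test_amp_gpu_ddp also randomizes MASTER_PORT before launching DDP. PyTorch's env:// rendezvous reads MASTER_ADDR and MASTER_PORT from the environment, so giving each run a fresh port in a high range avoids "address already in use" failures when several process-group tests run back to back. A minimal sketch of the same trick; the port range simply mirrors the test, and any free high port would do.

import os

import numpy as np

# np.random.randint(12000, 19000, 1) returns a length-1 array; [0] extracts the
# integer, which is stringified because environment variables must be strings.
os.environ['MASTER_PORT'] = str(np.random.randint(12000, 19000, 1)[0])
print('MASTER_PORT set to', os.environ['MASTER_PORT'])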