diff --git a/tests/test_models.py b/tests/test_models.py
index ce1697a91e..eeab97f00d 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -27,6 +27,34 @@ np.random.seed(SEED)
 # TESTS
 # ------------------------------------------------------------------------
+def test_amp_single_gpu():
+    """
+    Make sure DP + AMP work on a single GPU
+    :return:
+    """
+    if not torch.cuda.is_available():
+        warnings.warn('test_amp_single_gpu cannot run. '
+                      'Rerun on a GPU node to run this test')
+        return
+    if not torch.cuda.device_count() > 1:
+        warnings.warn('test_amp_single_gpu cannot run. '
+                      'Rerun on a node with 2+ GPUs to run this test')
+        return
+
+    hparams = get_hparams()
+    model = LightningTestModel(hparams)
+
+    trainer_options = dict(
+        progress_bar=True,
+        max_nb_epochs=1,
+        gpus=[0],
+        distributed_backend='dp',
+        use_amp=True
+    )
+
+    run_gpu_model_test(trainer_options, model, hparams)
+
+
 def test_cpu_restore_training():