diff --git a/tests/test_models.py b/tests/test_models.py index 84383b453a..15a41e9c8e 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -21,6 +21,28 @@ np.random.seed(SEED) # ------------------------------------------------------------------------ # TESTS # ------------------------------------------------------------------------ + + +def test_cpu_model_with_amp(): + """ + Make sure model trains on CPU + :return: + """ + + trainer_options = dict( + progress_bar=False, + experiment=get_exp(), + max_nb_epochs=1, + train_percent_check=0.4, + val_percent_check=0.4, + use_amp=True + ) + + model, hparams = get_model() + + run_gpu_model_test(trainer_options, model, hparams, on_gpu=False) + + def test_amp_gpu_ddp_slurm_managed(): """ Make sure DDP + AMP work @@ -111,26 +133,6 @@ def test_cpu_model(): run_gpu_model_test(trainer_options, model, hparams, on_gpu=False) -def test_cpu_model_with_amp(): - """ - Make sure model trains on CPU - :return: - """ - - trainer_options = dict( - progress_bar=False, - experiment=get_exp(), - max_nb_epochs=1, - train_percent_check=0.4, - val_percent_check=0.4, - use_amp=True - ) - - model, hparams = get_model() - - run_gpu_model_test(trainer_options, model, hparams, on_gpu=False) - - def test_all_features_cpu_model(): """ Test each of the trainer options