From 9e187574dec463f43c62a5560164bddc2486103e Mon Sep 17 00:00:00 2001 From: William Falcon Date: Wed, 24 Jul 2019 14:17:36 -0400 Subject: [PATCH] fixed amp bug --- tests/test_models.py | 20 ++++++-------------- 1 file changed, 6 insertions(+), 14 deletions(-) diff --git a/tests/test_models.py b/tests/test_models.py index be9db38658..96e9915698 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -23,24 +23,16 @@ def test_cpu_model(): Make sure model trains on CPU :return: """ - save_dir = init_save_dir() - model, hparams = get_model() - - trainer = Trainer( + trainer_options = dict( progress_bar=False, experiment=get_exp(), max_nb_epochs=1, train_percent_check=0.4, val_percent_check=0.4 ) - result = trainer.fit(model) - # correct result and ok accuracy - assert result == 1, 'cpu model failed to complete' - assert_ok_acc(trainer) - - clear_save_dir() + run_gpu_model_test(trainer_options, on_gpu=False) def test_single_gpu_model(): @@ -162,7 +154,7 @@ def test_amp_gpu_ddp(): # UTILS # ------------------------------------------------------------------------ -def run_gpu_model_test(trainer_options): +def run_gpu_model_test(trainer_options, on_gpu=True): """ Make sure DDP + AMP work :return: @@ -197,7 +189,7 @@ def run_gpu_model_test(trainer_options): assert result == 1, 'amp + ddp model failed to complete' # test model loading - pretrained_model = load_model(exp, save_dir) + pretrained_model = load_model(exp, save_dir, on_gpu) # test model preds run_prediction(model.test_dataloader, pretrained_model) @@ -247,7 +239,7 @@ def clear_save_dir(): shutil.rmtree(save_dir) -def load_model(exp, save_dir): +def load_model(exp, save_dir, on_gpu): # load trained model tags_path = exp.get_data_path(exp.name, exp.version) @@ -256,7 +248,7 @@ def load_model(exp, save_dir): checkpoints = [x for x in os.listdir(save_dir) if '.ckpt' in x] weights_dir = os.path.join(save_dir, checkpoints[0]) - trained_model = LightningTemplateModel.load_from_metrics(weights_path=weights_dir, tags_csv=tags_path, on_gpu=True) + trained_model = LightningTemplateModel.load_from_metrics(weights_path=weights_dir, tags_csv=tags_path, on_gpu=on_gpu) assert trained_model is not None, 'loading model failed'