diff --git a/tests/test_models.py b/tests/test_models.py index 04ca68c65b..ccd63398e6 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -277,7 +277,7 @@ def run_gpu_model_test(trainer_options, model, hparams, on_gpu=True): # test HPC loading / saving trainer.hpc_save(save_dir, exp) - trainer.hpc_load(save_dir, on_gpu=True) + trainer.hpc_load(save_dir, on_gpu=on_gpu) clear_save_dir()