diff --git a/tests/test_models.py b/tests/test_models.py index 553a5ccc0a..67c677a248 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -190,7 +190,9 @@ def test_amp_gpu_ddp(): os.environ['MASTER_PORT'] = str(np.random.randint(12000, 19000, 1)[0]) - model, hparams = get_model() + hparams = get_hparams() + model = LightningTestModel(hparams) + trainer_options = dict( progress_bar=True, max_nb_epochs=1,