diff --git a/tests/tests_pytorch/models/test_restore.py b/tests/tests_pytorch/models/test_restore.py index c7a28f25d9..211dc49d42 100644 --- a/tests/tests_pytorch/models/test_restore.py +++ b/tests/tests_pytorch/models/test_restore.py @@ -446,6 +446,7 @@ def test_load_model_from_checkpoint(tmp_path, model_template): "limit_test_batches": 2, "callbacks": [ModelCheckpoint(dirpath=tmp_path, monitor="val_loss", save_top_k=-1)], "default_root_dir": tmp_path, + "accelerator": "cpu", } # fit model