diff --git a/tests/test_models.py b/tests/test_models.py index 768acd05b7..0bf8d35078 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -182,6 +182,7 @@ def test_cpu_slurm_save_load(): assert trainer.global_step == real_global_step and trainer.global_step != 20000000 # predict with loaded model to make sure answers are the same + trainer.model.eval() new_pred = trainer.model(x) assert torch.all(torch.eq(pred_before_saving, new_pred)).item() == 1