diff --git a/tests/test_models.py b/tests/test_models.py index 1a3156250d..bdd31a5c75 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -115,7 +115,7 @@ def test_model_saving_loading(): # make prediction # assert that both predictions are the same new_pred = model_2(x) - assert torch.eq(pred_before_saving, new_pred) + assert torch.all(torch.eq(pred_before_saving, new_pred)).item() == 1 clear_save_dir()