diff --git a/tests/test_models.py b/tests/test_models.py index 575ceb0551..d3e694fc43 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -86,7 +86,7 @@ def test_cpu_restore_training(): # set the epoch start hook so we can predict before the model does the full training def assert_pred_same(): - assert trainer.current_epoch == real_global_epoch and trainer.real_global_epoch > 0 + assert trainer.current_epoch == real_global_epoch and trainer.current_epoch > 0 # predict with loaded model to make sure answers are the same trainer.model.eval()