diff --git a/tests/test_models.py b/tests/test_models.py index 14f1e4acde..9142ca452d 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -42,7 +42,7 @@ def test_cpu_restore_training(): exp.save() trainer_options = dict( - max_nb_epochs=1, + max_nb_epochs=2, val_check_interval=0.50, val_percent_check=0.2, train_percent_check=0.2, @@ -53,7 +53,7 @@ def test_cpu_restore_training(): # fit model trainer = Trainer(**trainer_options) result = trainer.fit(model) - real_global_step = trainer.global_step + real_global_epoch = trainer.current_epoch # traning complete assert result == 1, 'amp + ddp model failed to complete' @@ -86,7 +86,7 @@ def test_cpu_restore_training(): # set the epoch start hook so we can predict before the model does the full training def assert_pred_same(): - assert trainer.global_step == real_global_step and trainer.global_step > 0 + assert trainer.current_epoch == real_global_epoch and trainer.real_global_epoch > 0 # predict with loaded model to make sure answers are the same trainer.model.eval()