diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index e6af0503cf..0ebaab553b 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -574,7 +574,114 @@ def test_trainer_model_hook_system_fit(tmpdir, kwargs, automatic_optimization): assert called == expected -def test_trainer_model_hook_system_fit_no_val_and_resume(tmpdir): +def test_trainer_model_hook_system_fit_no_val_and_resume_max_epochs(tmpdir): + # initial training to get a checkpoint + model = BoringModel() + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + limit_train_batches=2, + limit_val_batches=0, + enable_progress_bar=False, + enable_model_summary=False, + callbacks=[HookedCallback([])], + ) + trainer.fit(model) + best_model_path = trainer.checkpoint_callback.best_model_path + + called = [] + callback = HookedCallback(called) + # already performed 1 step, resume and do 2 more + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=2, + limit_train_batches=2, + limit_val_batches=0, + enable_progress_bar=False, + enable_model_summary=False, + callbacks=[callback], + track_grad_norm=1, + ) + assert called == [ + dict(name="Callback.on_init_start", args=(trainer,)), + dict(name="Callback.on_init_end", args=(trainer,)), + ] + + # resume from checkpoint with HookedModel + model = HookedModel(called) + trainer.fit(model, ckpt_path=best_model_path) + loaded_ckpt = { + "callbacks": ANY, + "epoch": 0, + "global_step": 2, + "lr_schedulers": ANY, + "optimizer_states": ANY, + "pytorch-lightning_version": __version__, + "state_dict": ANY, + "loops": ANY, + } + saved_ckpt1 = {**loaded_ckpt, "global_step": 2, "epoch": 0} + saved_ckpt2 = {**loaded_ckpt, "global_step": 4, "epoch": 1} + expected = [ + dict(name="Callback.on_init_start", args=(trainer,)), + dict(name="Callback.on_init_end", args=(trainer,)), + dict(name="configure_callbacks"), + dict(name="prepare_data"), + dict(name="Callback.on_before_accelerator_backend_setup", args=(trainer, model)), + dict(name="Callback.setup", args=(trainer, model), kwargs=dict(stage="fit")), + dict(name="setup", kwargs=dict(stage="fit")), + dict(name="on_load_checkpoint", args=(loaded_ckpt,)), + dict(name="Callback.on_load_checkpoint", args=(trainer, model, {"foo": True})), + dict(name="Callback.load_state_dict", args=({"foo": True},)), + dict(name="configure_sharded_model"), + dict(name="Callback.on_configure_sharded_model", args=(trainer, model)), + dict(name="configure_optimizers"), + dict(name="Callback.on_fit_start", args=(trainer, model)), + dict(name="on_fit_start"), + dict(name="Callback.on_pretrain_routine_start", args=(trainer, model)), + dict(name="on_pretrain_routine_start"), + dict(name="Callback.on_pretrain_routine_end", args=(trainer, model)), + dict(name="on_pretrain_routine_end"), + dict(name="train", args=(True,)), + dict(name="on_train_dataloader"), + dict(name="train_dataloader"), + dict(name="Callback.on_train_start", args=(trainer, model)), + dict(name="on_train_start"), + dict(name="Callback.on_epoch_start", args=(trainer, model)), + dict(name="on_epoch_start"), + dict(name="Callback.on_train_epoch_start", args=(trainer, model)), + dict(name="on_train_epoch_start"), + dict(name="Callback.on_train_epoch_end", args=(trainer, model)), + dict(name="Callback.state_dict"), + dict(name="Callback.on_save_checkpoint", args=(trainer, model, saved_ckpt1)), + dict(name="on_save_checkpoint", args=(saved_ckpt1,)), + dict(name="on_train_epoch_end"), + dict(name="Callback.on_epoch_end", args=(trainer, model)), + dict(name="on_epoch_end"), + dict(name="Callback.on_epoch_start", args=(trainer, model)), + dict(name="on_epoch_start"), + dict(name="Callback.on_train_epoch_start", args=(trainer, model)), + dict(name="on_train_epoch_start"), + *model._train_batch(trainer, model, 2, current_epoch=1, current_batch=0), + dict(name="training_epoch_end", args=([dict(loss=ANY)] * 2,)), + dict(name="Callback.on_train_epoch_end", args=(trainer, model)), + dict(name="Callback.state_dict"), + dict(name="Callback.on_save_checkpoint", args=(trainer, model, saved_ckpt2)), + dict(name="on_save_checkpoint", args=(saved_ckpt2,)), + dict(name="on_train_epoch_end"), + dict(name="Callback.on_epoch_end", args=(trainer, model)), + dict(name="on_epoch_end"), + dict(name="Callback.on_train_end", args=(trainer, model)), + dict(name="on_train_end"), + dict(name="Callback.on_fit_end", args=(trainer, model)), + dict(name="on_fit_end"), + dict(name="Callback.teardown", args=(trainer, model), kwargs=dict(stage="fit")), + dict(name="teardown", kwargs=dict(stage="fit")), + ] + assert called == expected + + +def test_trainer_model_hook_system_fit_no_val_and_resume_max_steps(tmpdir): # initial training to get a checkpoint model = BoringModel() trainer = Trainer(