Add hook test for reloading with max epochs (#12932)

This commit is contained in:
Carlos Mocholí 2022-05-02 14:41:28 +02:00 committed by GitHub
parent 456cc87954
commit 26acdd6569
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 108 additions and 1 deletions

View File

@ -574,7 +574,114 @@ def test_trainer_model_hook_system_fit(tmpdir, kwargs, automatic_optimization):
assert called == expected
def test_trainer_model_hook_system_fit_no_val_and_resume(tmpdir):
def test_trainer_model_hook_system_fit_no_val_and_resume_max_epochs(tmpdir):
# initial training to get a checkpoint
model = BoringModel()
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
limit_train_batches=2,
limit_val_batches=0,
enable_progress_bar=False,
enable_model_summary=False,
callbacks=[HookedCallback([])],
)
trainer.fit(model)
best_model_path = trainer.checkpoint_callback.best_model_path
called = []
callback = HookedCallback(called)
# already performed 1 step, resume and do 2 more
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=2,
limit_train_batches=2,
limit_val_batches=0,
enable_progress_bar=False,
enable_model_summary=False,
callbacks=[callback],
track_grad_norm=1,
)
assert called == [
dict(name="Callback.on_init_start", args=(trainer,)),
dict(name="Callback.on_init_end", args=(trainer,)),
]
# resume from checkpoint with HookedModel
model = HookedModel(called)
trainer.fit(model, ckpt_path=best_model_path)
loaded_ckpt = {
"callbacks": ANY,
"epoch": 0,
"global_step": 2,
"lr_schedulers": ANY,
"optimizer_states": ANY,
"pytorch-lightning_version": __version__,
"state_dict": ANY,
"loops": ANY,
}
saved_ckpt1 = {**loaded_ckpt, "global_step": 2, "epoch": 0}
saved_ckpt2 = {**loaded_ckpt, "global_step": 4, "epoch": 1}
expected = [
dict(name="Callback.on_init_start", args=(trainer,)),
dict(name="Callback.on_init_end", args=(trainer,)),
dict(name="configure_callbacks"),
dict(name="prepare_data"),
dict(name="Callback.on_before_accelerator_backend_setup", args=(trainer, model)),
dict(name="Callback.setup", args=(trainer, model), kwargs=dict(stage="fit")),
dict(name="setup", kwargs=dict(stage="fit")),
dict(name="on_load_checkpoint", args=(loaded_ckpt,)),
dict(name="Callback.on_load_checkpoint", args=(trainer, model, {"foo": True})),
dict(name="Callback.load_state_dict", args=({"foo": True},)),
dict(name="configure_sharded_model"),
dict(name="Callback.on_configure_sharded_model", args=(trainer, model)),
dict(name="configure_optimizers"),
dict(name="Callback.on_fit_start", args=(trainer, model)),
dict(name="on_fit_start"),
dict(name="Callback.on_pretrain_routine_start", args=(trainer, model)),
dict(name="on_pretrain_routine_start"),
dict(name="Callback.on_pretrain_routine_end", args=(trainer, model)),
dict(name="on_pretrain_routine_end"),
dict(name="train", args=(True,)),
dict(name="on_train_dataloader"),
dict(name="train_dataloader"),
dict(name="Callback.on_train_start", args=(trainer, model)),
dict(name="on_train_start"),
dict(name="Callback.on_epoch_start", args=(trainer, model)),
dict(name="on_epoch_start"),
dict(name="Callback.on_train_epoch_start", args=(trainer, model)),
dict(name="on_train_epoch_start"),
dict(name="Callback.on_train_epoch_end", args=(trainer, model)),
dict(name="Callback.state_dict"),
dict(name="Callback.on_save_checkpoint", args=(trainer, model, saved_ckpt1)),
dict(name="on_save_checkpoint", args=(saved_ckpt1,)),
dict(name="on_train_epoch_end"),
dict(name="Callback.on_epoch_end", args=(trainer, model)),
dict(name="on_epoch_end"),
dict(name="Callback.on_epoch_start", args=(trainer, model)),
dict(name="on_epoch_start"),
dict(name="Callback.on_train_epoch_start", args=(trainer, model)),
dict(name="on_train_epoch_start"),
*model._train_batch(trainer, model, 2, current_epoch=1, current_batch=0),
dict(name="training_epoch_end", args=([dict(loss=ANY)] * 2,)),
dict(name="Callback.on_train_epoch_end", args=(trainer, model)),
dict(name="Callback.state_dict"),
dict(name="Callback.on_save_checkpoint", args=(trainer, model, saved_ckpt2)),
dict(name="on_save_checkpoint", args=(saved_ckpt2,)),
dict(name="on_train_epoch_end"),
dict(name="Callback.on_epoch_end", args=(trainer, model)),
dict(name="on_epoch_end"),
dict(name="Callback.on_train_end", args=(trainer, model)),
dict(name="on_train_end"),
dict(name="Callback.on_fit_end", args=(trainer, model)),
dict(name="on_fit_end"),
dict(name="Callback.teardown", args=(trainer, model), kwargs=dict(stage="fit")),
dict(name="teardown", kwargs=dict(stage="fit")),
]
assert called == expected
def test_trainer_model_hook_system_fit_no_val_and_resume_max_steps(tmpdir):
# initial training to get a checkpoint
model = BoringModel()
trainer = Trainer(