Add hook test for reloading with max epochs (#12932)
This commit is contained in:
parent
456cc87954
commit
26acdd6569
|
@ -574,7 +574,114 @@ def test_trainer_model_hook_system_fit(tmpdir, kwargs, automatic_optimization):
|
|||
assert called == expected
|
||||
|
||||
|
||||
def test_trainer_model_hook_system_fit_no_val_and_resume(tmpdir):
|
||||
def test_trainer_model_hook_system_fit_no_val_and_resume_max_epochs(tmpdir):
    """Verify the exact hook-call order when ``fit()`` resumes from a checkpoint with a larger ``max_epochs``."""
    # Trainer options shared by both runs below.
    common_kwargs = dict(
        default_root_dir=tmpdir,
        limit_train_batches=2,
        limit_val_batches=0,
        enable_progress_bar=False,
        enable_model_summary=False,
    )

    # Initial run: train one epoch (2 batches) only to produce a checkpoint to resume from.
    trainer = Trainer(max_epochs=1, callbacks=[HookedCallback([])], **common_kwargs)
    trainer.fit(BoringModel())
    best_model_path = trainer.checkpoint_callback.best_model_path

    called = []
    hook_callback = HookedCallback(called)
    # Second run: raising max_epochs to 2 means resuming trains exactly one more epoch.
    trainer = Trainer(max_epochs=2, callbacks=[hook_callback], track_grad_norm=1, **common_kwargs)
    # Before fit() is invoked, only the constructor hooks have fired.
    assert called == [
        {"name": "Callback.on_init_start", "args": (trainer,)},
        {"name": "Callback.on_init_end", "args": (trainer,)},
    ]

    # resume from checkpoint with HookedModel
    model = HookedModel(called)
    trainer.fit(model, ckpt_path=best_model_path)

    # Shape of the checkpoint dict passed to the load/save hooks; tensor contents are irrelevant.
    loaded_ckpt = {
        "callbacks": ANY,
        "epoch": 0,
        "global_step": 2,
        "lr_schedulers": ANY,
        "optimizer_states": ANY,
        "pytorch-lightning_version": __version__,
        "state_dict": ANY,
        "loops": ANY,
    }
    # Checkpoints saved at the end of the re-entered epoch 0 and the fresh epoch 1.
    saved_ckpt1 = {**loaded_ckpt, "global_step": 2, "epoch": 0}
    saved_ckpt2 = {**loaded_ckpt, "global_step": 4, "epoch": 1}

    expected = [
        {"name": "Callback.on_init_start", "args": (trainer,)},
        {"name": "Callback.on_init_end", "args": (trainer,)},
        {"name": "configure_callbacks"},
        {"name": "prepare_data"},
        {"name": "Callback.on_before_accelerator_backend_setup", "args": (trainer, model)},
        {"name": "Callback.setup", "args": (trainer, model), "kwargs": {"stage": "fit"}},
        {"name": "setup", "kwargs": {"stage": "fit"}},
        {"name": "on_load_checkpoint", "args": (loaded_ckpt,)},
        {"name": "Callback.on_load_checkpoint", "args": (trainer, model, {"foo": True})},
        {"name": "Callback.load_state_dict", "args": ({"foo": True},)},
        {"name": "configure_sharded_model"},
        {"name": "Callback.on_configure_sharded_model", "args": (trainer, model)},
        {"name": "configure_optimizers"},
        {"name": "Callback.on_fit_start", "args": (trainer, model)},
        {"name": "on_fit_start"},
        {"name": "Callback.on_pretrain_routine_start", "args": (trainer, model)},
        {"name": "on_pretrain_routine_start"},
        {"name": "Callback.on_pretrain_routine_end", "args": (trainer, model)},
        {"name": "on_pretrain_routine_end"},
        {"name": "train", "args": (True,)},
        {"name": "on_train_dataloader"},
        {"name": "train_dataloader"},
        {"name": "Callback.on_train_start", "args": (trainer, model)},
        {"name": "on_train_start"},
        # Epoch 0 is re-entered on restore but has no batches left, so it closes
        # immediately after saving its end-of-epoch checkpoint.
        {"name": "Callback.on_epoch_start", "args": (trainer, model)},
        {"name": "on_epoch_start"},
        {"name": "Callback.on_train_epoch_start", "args": (trainer, model)},
        {"name": "on_train_epoch_start"},
        {"name": "Callback.on_train_epoch_end", "args": (trainer, model)},
        {"name": "Callback.state_dict"},
        {"name": "Callback.on_save_checkpoint", "args": (trainer, model, saved_ckpt1)},
        {"name": "on_save_checkpoint", "args": (saved_ckpt1,)},
        {"name": "on_train_epoch_end"},
        {"name": "Callback.on_epoch_end", "args": (trainer, model)},
        {"name": "on_epoch_end"},
        # Epoch 1 runs the two fresh training batches.
        {"name": "Callback.on_epoch_start", "args": (trainer, model)},
        {"name": "on_epoch_start"},
        {"name": "Callback.on_train_epoch_start", "args": (trainer, model)},
        {"name": "on_train_epoch_start"},
        *model._train_batch(trainer, model, 2, current_epoch=1, current_batch=0),
        {"name": "training_epoch_end", "args": ([{"loss": ANY}] * 2,)},
        {"name": "Callback.on_train_epoch_end", "args": (trainer, model)},
        {"name": "Callback.state_dict"},
        {"name": "Callback.on_save_checkpoint", "args": (trainer, model, saved_ckpt2)},
        {"name": "on_save_checkpoint", "args": (saved_ckpt2,)},
        {"name": "on_train_epoch_end"},
        {"name": "Callback.on_epoch_end", "args": (trainer, model)},
        {"name": "on_epoch_end"},
        {"name": "Callback.on_train_end", "args": (trainer, model)},
        {"name": "on_train_end"},
        {"name": "Callback.on_fit_end", "args": (trainer, model)},
        {"name": "on_fit_end"},
        {"name": "Callback.teardown", "args": (trainer, model), "kwargs": {"stage": "fit"}},
        {"name": "teardown", "kwargs": {"stage": "fit"}},
    ]
    assert called == expected
|
||||
|
||||
|
||||
def test_trainer_model_hook_system_fit_no_val_and_resume_max_steps(tmpdir):
|
||||
# initial training to get a checkpoint
|
||||
model = BoringModel()
|
||||
trainer = Trainer(
|
||||
|
|
Loading…
Reference in New Issue