Test `Callback.on_load_checkpoint` order (#8588)

This commit is contained in:
Carlos Mocholí 2021-07-29 12:28:29 +02:00 committed by GitHub
parent 7901d297d3
commit c99e2fe0d2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 15 additions and 3 deletions

View File

@ -213,16 +213,22 @@ def get_members(cls):
class HookedCallback(Callback):
def __init__(self, called):
def call(hook, *args, **kwargs):
def call(hook, fn, *args, **kwargs):
out = fn(*args, **kwargs)
d = {"name": f"Callback.{hook}"}
if args:
d["args"] = args
if kwargs:
d["kwargs"] = kwargs
called.append(d)
return out
for h in get_members(Callback):
setattr(self, h, partial(call, h))
attr = getattr(self, h)
setattr(self, h, partial(call, h, attr))
def on_save_checkpoint(*args, **kwargs):
return {"foo": True}
class HookedModel(BoringModel):
@ -555,7 +561,12 @@ def test_trainer_model_hook_system_fit_no_val_and_resume(tmpdir):
# initial training to get a checkpoint
model = BoringModel()
trainer = Trainer(
default_root_dir=tmpdir, max_steps=1, limit_val_batches=0, progress_bar_refresh_rate=0, weights_summary=None
default_root_dir=tmpdir,
max_steps=1,
limit_val_batches=0,
progress_bar_refresh_rate=0,
weights_summary=None,
callbacks=[HookedCallback([])],
)
trainer.fit(model)
best_model_path = trainer.checkpoint_callback.best_model_path
@ -611,6 +622,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume(tmpdir):
},
),
),
dict(name="Callback.on_load_checkpoint", args=(trainer, model, {"foo": True})),
dict(name="configure_sharded_model"),
dict(name="Callback.on_configure_sharded_model", args=(trainer, model)),
dict(name="configure_optimizers"),