Test `Callback.on_load_checkpoint` order (#8588)
This commit is contained in:
parent
7901d297d3
commit
c99e2fe0d2
|
@ -213,16 +213,22 @@ def get_members(cls):
|
|||
|
||||
class HookedCallback(Callback):
|
||||
def __init__(self, called):
|
||||
def call(hook, *args, **kwargs):
|
||||
def call(hook, fn, *args, **kwargs):
|
||||
out = fn(*args, **kwargs)
|
||||
d = {"name": f"Callback.{hook}"}
|
||||
if args:
|
||||
d["args"] = args
|
||||
if kwargs:
|
||||
d["kwargs"] = kwargs
|
||||
called.append(d)
|
||||
return out
|
||||
|
||||
for h in get_members(Callback):
|
||||
setattr(self, h, partial(call, h))
|
||||
attr = getattr(self, h)
|
||||
setattr(self, h, partial(call, h, attr))
|
||||
|
||||
def on_save_checkpoint(*args, **kwargs):
|
||||
return {"foo": True}
|
||||
|
||||
|
||||
class HookedModel(BoringModel):
|
||||
|
@ -555,7 +561,12 @@ def test_trainer_model_hook_system_fit_no_val_and_resume(tmpdir):
|
|||
# initial training to get a checkpoint
|
||||
model = BoringModel()
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir, max_steps=1, limit_val_batches=0, progress_bar_refresh_rate=0, weights_summary=None
|
||||
default_root_dir=tmpdir,
|
||||
max_steps=1,
|
||||
limit_val_batches=0,
|
||||
progress_bar_refresh_rate=0,
|
||||
weights_summary=None,
|
||||
callbacks=[HookedCallback([])],
|
||||
)
|
||||
trainer.fit(model)
|
||||
best_model_path = trainer.checkpoint_callback.best_model_path
|
||||
|
@ -611,6 +622,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume(tmpdir):
|
|||
},
|
||||
),
|
||||
),
|
||||
dict(name="Callback.on_load_checkpoint", args=(trainer, model, {"foo": True})),
|
||||
dict(name="configure_sharded_model"),
|
||||
dict(name="Callback.on_configure_sharded_model", args=(trainer, model)),
|
||||
dict(name="configure_optimizers"),
|
||||
|
|
Loading…
Reference in New Issue