Fix failing master due to an interction between PRs (#10627)

This commit is contained in:
Carlos Mocholí 2021-11-19 03:04:53 +01:00 committed by GitHub
parent 35f6cbe09f
commit 0de8ab4f2e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 5 additions and 10 deletions

View File

@ -485,8 +485,7 @@ def _run_trainer_model_hook_system_fit(kwargs, tmpdir, automatic_optimization):
dict(name="Callback.on_init_start", args=(trainer,)),
dict(name="Callback.on_init_end", args=(trainer,)),
]
with pytest.deprecated_call(match="on_train_dataloader` is deprecated in v1.5"):
trainer.fit(model)
trainer.fit(model)
saved_ckpt = {
"callbacks": ANY,
"epoch": 1,
@ -588,8 +587,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume(tmpdir):
enable_model_summary=False,
callbacks=[HookedCallback([])],
)
with pytest.deprecated_call(match="on_keyboard_interrupt` callback hook was deprecated in v1.5"):
trainer.fit(model)
trainer.fit(model)
best_model_path = trainer.checkpoint_callback.best_model_path
# resume from checkpoint with HookedModel
@ -611,8 +609,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume(tmpdir):
dict(name="Callback.on_init_start", args=(trainer,)),
dict(name="Callback.on_init_end", args=(trainer,)),
]
with pytest.deprecated_call(match="on_train_dataloader` is deprecated in v1.5"):
trainer.fit(model, ckpt_path=best_model_path)
trainer.fit(model, ckpt_path=best_model_path)
saved_ckpt = {
"callbacks": ANY,
"epoch": 2, # TODO: wrong saved epoch
@ -707,8 +704,7 @@ def test_trainer_model_hook_system_eval(tmpdir, batches, verb, noun, dataloader,
dict(name="Callback.on_init_end", args=(trainer,)),
]
fn = getattr(trainer, verb)
with pytest.deprecated_call(match=f"on_{dataloader}_dataloader` is deprecated in v1.5"):
fn(model, verbose=False)
fn(model, verbose=False)
hooks = [
dict(name="train", args=(False,)),
dict(name=f"on_{noun}_model_eval"),
@ -752,8 +748,7 @@ def test_trainer_model_hook_system_predict(tmpdir):
dict(name="Callback.on_init_start", args=(trainer,)),
dict(name="Callback.on_init_end", args=(trainer,)),
]
with pytest.deprecated_call(match="on_predict_dataloader` is deprecated in v1.5"):
trainer.predict(model)
trainer.predict(model)
expected = [
dict(name="Callback.on_init_start", args=(trainer,)),
dict(name="Callback.on_init_end", args=(trainer,)),