Use non-deprecated options in tests (#9949)

This commit is contained in:
Carlos Mocholí 2021-10-16 01:58:07 +02:00 committed by GitHub
parent db4e770004
commit e973bcb76a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 16 additions and 15 deletions

View File

@ -46,7 +46,7 @@ def test_device_stats_gpu_from_torch(tmpdir):
gpus=1,
callbacks=[device_stats],
logger=DebugLogger(tmpdir),
checkpoint_callback=False,
enable_checkpointing=False,
enable_progress_bar=False,
)
@ -75,7 +75,7 @@ def test_device_stats_gpu_from_nvidia(tmpdir):
gpus=1,
callbacks=[device_stats],
logger=DebugLogger(tmpdir),
checkpoint_callback=False,
enable_checkpointing=False,
enable_progress_bar=False,
)
@ -104,7 +104,7 @@ def test_device_stats_monitor_tpu(tmpdir):
log_every_n_steps=1,
callbacks=[device_stats],
logger=DebugLogger(tmpdir),
checkpoint_callback=False,
enable_checkpointing=False,
enable_progress_bar=False,
)
@ -122,7 +122,7 @@ def test_device_stats_monitor_no_logger(tmpdir):
callbacks=[device_stats],
max_epochs=1,
logger=False,
checkpoint_callback=False,
enable_checkpointing=False,
enable_progress_bar=False,
)

View File

@ -32,8 +32,8 @@ def test_model_summary_callback_present_trainer():
def test_model_summary_callback_with_weights_summary_none():
trainer = Trainer(weights_summary=None)
with pytest.deprecated_call(match=r"weights_summary=None\)` is deprecated"):
trainer = Trainer(weights_summary=None)
assert not any(isinstance(cb, ModelSummary) for cb in trainer.callbacks)
trainer = Trainer(enable_model_summary=False)
@ -42,7 +42,8 @@ def test_model_summary_callback_with_weights_summary_none():
trainer = Trainer(enable_model_summary=False, weights_summary="full")
assert not any(isinstance(cb, ModelSummary) for cb in trainer.callbacks)
trainer = Trainer(enable_model_summary=True, weights_summary=None)
with pytest.deprecated_call(match=r"weights_summary=None\)` is deprecated"):
trainer = Trainer(enable_model_summary=True, weights_summary=None)
assert not any(isinstance(cb, ModelSummary) for cb in trainer.callbacks)

View File

@ -153,7 +153,7 @@ def test_trainer_properties_restore_resume_from_checkpoint(tmpdir):
state_dict = torch.load(resume_ckpt)
trainer_args.update(
{"max_epochs": 3, "resume_from_checkpoint": resume_ckpt, "checkpoint_callback": False, "callbacks": []}
{"max_epochs": 3, "resume_from_checkpoint": resume_ckpt, "enable_checkpointing": False, "callbacks": []}
)
class CustomClassifModel(CustomClassifModel):

View File

@ -468,7 +468,7 @@ def test_trainer_max_steps_and_epochs(tmpdir):
"max_epochs": 3,
"max_steps": num_train_samples + 10,
"logger": False,
"weights_summary": None,
"enable_model_summary": False,
"enable_progress_bar": False,
}
trainer = Trainer(**trainer_kwargs)
@ -555,7 +555,7 @@ def test_trainer_min_steps_and_epochs(tmpdir):
# define less min steps than 1 epoch
"min_steps": num_train_samples // 2,
"logger": False,
"weights_summary": None,
"enable_model_summary": False,
"enable_progress_bar": False,
}
trainer = Trainer(**trainer_kwargs)
@ -723,9 +723,9 @@ def test_tested_checkpoint_path(tmpdir, ckpt_path, save_top_k, fn):
assert getattr(trainer, path_attr) == ckpt_path
@pytest.mark.parametrize("checkpoint_callback", (False, True))
@pytest.mark.parametrize("enable_checkpointing", (False, True))
@pytest.mark.parametrize("fn", ("validate", "test", "predict"))
def test_tested_checkpoint_path_best(tmpdir, checkpoint_callback, fn):
def test_tested_checkpoint_path_best(tmpdir, enable_checkpointing, fn):
class TestModel(BoringModel):
def validation_step(self, batch, batch_idx):
self.log("foo", -batch_idx)
@ -746,7 +746,7 @@ def test_tested_checkpoint_path_best(tmpdir, checkpoint_callback, fn):
limit_predict_batches=1,
enable_progress_bar=False,
default_root_dir=tmpdir,
checkpoint_callback=checkpoint_callback,
enable_checkpointing=enable_checkpointing,
)
trainer.fit(model)
@ -754,7 +754,7 @@ def test_tested_checkpoint_path_best(tmpdir, checkpoint_callback, fn):
path_attr = f"{fn}{'d' if fn == 'validate' else 'ed'}_ckpt_path"
assert getattr(trainer, path_attr) is None
if checkpoint_callback:
if enable_checkpointing:
trainer_fn(ckpt_path="best")
assert getattr(trainer, path_attr) == trainer.checkpoint_callback.best_model_path

View File

@ -384,7 +384,7 @@ def test_lightning_cli_config_and_subclass_mode(tmpdir):
"fit": {
"model": {"class_path": "tests.helpers.BoringModel"},
"data": {"class_path": "tests.helpers.BoringDataModule", "init_args": {"data_dir": str(tmpdir)}},
"trainer": {"default_root_dir": str(tmpdir), "max_epochs": 1, "weights_summary": None},
"trainer": {"default_root_dir": str(tmpdir), "max_epochs": 1, "enable_model_summary": False},
}
}
config_path = tmpdir / "config.yaml"