Use non-deprecated options in tests (#9949)

parent db4e770004
commit e973bcb76a
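This commit migrates the test suite off Trainer arguments that were deprecated in PyTorch Lightning v1.5: checkpoint_callback=False becomes enable_checkpointing=False, and weights_summary=None becomes enable_model_summary=False. Tests that deliberately keep the deprecated spelling now assert the deprecation warning instead of letting it leak. As a minimal sketch for orientation (not part of the diff below, assuming Lightning >= 1.5), a Trainer built with only the non-deprecated flags looks like this:

# Illustrative sketch, not taken from this commit.
from pytorch_lightning import Trainer

trainer = Trainer(
    enable_checkpointing=False,  # replaces the deprecated checkpoint_callback=False
    enable_model_summary=False,  # replaces the deprecated weights_summary=None
    enable_progress_bar=False,   # already the non-deprecated spelling in these tests
)
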
@@ -46,7 +46,7 @@ def test_device_stats_gpu_from_torch(tmpdir):
         gpus=1,
         callbacks=[device_stats],
         logger=DebugLogger(tmpdir),
-        checkpoint_callback=False,
+        enable_checkpointing=False,
         enable_progress_bar=False,
     )

@@ -75,7 +75,7 @@ def test_device_stats_gpu_from_nvidia(tmpdir):
         gpus=1,
         callbacks=[device_stats],
         logger=DebugLogger(tmpdir),
-        checkpoint_callback=False,
+        enable_checkpointing=False,
         enable_progress_bar=False,
     )

@@ -104,7 +104,7 @@ def test_device_stats_monitor_tpu(tmpdir):
         log_every_n_steps=1,
         callbacks=[device_stats],
         logger=DebugLogger(tmpdir),
-        checkpoint_callback=False,
+        enable_checkpointing=False,
         enable_progress_bar=False,
     )

@@ -122,7 +122,7 @@ def test_device_stats_monitor_no_logger(tmpdir):
         callbacks=[device_stats],
         max_epochs=1,
         logger=False,
-        checkpoint_callback=False,
+        enable_checkpointing=False,
         enable_progress_bar=False,
     )

@@ -32,8 +32,8 @@ def test_model_summary_callback_present_trainer():


 def test_model_summary_callback_with_weights_summary_none():
-    trainer = Trainer(weights_summary=None)
+    with pytest.deprecated_call(match=r"weights_summary=None\)` is deprecated"):
+        trainer = Trainer(weights_summary=None)
     assert not any(isinstance(cb, ModelSummary) for cb in trainer.callbacks)

     trainer = Trainer(enable_model_summary=False)

@@ -42,7 +42,8 @@ def test_model_summary_callback_with_weights_summary_none():
     trainer = Trainer(enable_model_summary=False, weights_summary="full")
     assert not any(isinstance(cb, ModelSummary) for cb in trainer.callbacks)

-    trainer = Trainer(enable_model_summary=True, weights_summary=None)
+    with pytest.deprecated_call(match=r"weights_summary=None\)` is deprecated"):
+        trainer = Trainer(enable_model_summary=True, weights_summary=None)
     assert not any(isinstance(cb, ModelSummary) for cb in trainer.callbacks)

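Where a test intentionally keeps the deprecated weights_summary argument, the hunks above wrap the call in pytest.deprecated_call so the warning is asserted rather than silently emitted. A self-contained sketch of that pattern, using a hypothetical deprecated function rather than Lightning itself:

import warnings

import pytest


def old_api():
    # Hypothetical deprecated function, for illustration only.
    warnings.warn("old_api() is deprecated", DeprecationWarning)
    return 42


def test_old_api_warns():
    # Fails if the body does NOT emit a DeprecationWarning matching the pattern.
    with pytest.deprecated_call(match=r"is deprecated"):
        assert old_api() == 42
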
@@ -153,7 +153,7 @@ def test_trainer_properties_restore_resume_from_checkpoint(tmpdir):
     state_dict = torch.load(resume_ckpt)

     trainer_args.update(
-        {"max_epochs": 3, "resume_from_checkpoint": resume_ckpt, "checkpoint_callback": False, "callbacks": []}
+        {"max_epochs": 3, "resume_from_checkpoint": resume_ckpt, "enable_checkpointing": False, "callbacks": []}
     )

     class CustomClassifModel(CustomClassifModel):

@@ -468,7 +468,7 @@ def test_trainer_max_steps_and_epochs(tmpdir):
         "max_epochs": 3,
         "max_steps": num_train_samples + 10,
         "logger": False,
-        "weights_summary": None,
+        "enable_model_summary": False,
         "enable_progress_bar": False,
     }
     trainer = Trainer(**trainer_kwargs)

@@ -555,7 +555,7 @@ def test_trainer_min_steps_and_epochs(tmpdir):
         # define less min steps than 1 epoch
         "min_steps": num_train_samples // 2,
         "logger": False,
-        "weights_summary": None,
+        "enable_model_summary": False,
         "enable_progress_bar": False,
     }
     trainer = Trainer(**trainer_kwargs)

@@ -723,9 +723,9 @@ def test_tested_checkpoint_path(tmpdir, ckpt_path, save_top_k, fn):
     assert getattr(trainer, path_attr) == ckpt_path


-@pytest.mark.parametrize("checkpoint_callback", (False, True))
+@pytest.mark.parametrize("enable_checkpointing", (False, True))
 @pytest.mark.parametrize("fn", ("validate", "test", "predict"))
-def test_tested_checkpoint_path_best(tmpdir, checkpoint_callback, fn):
+def test_tested_checkpoint_path_best(tmpdir, enable_checkpointing, fn):
     class TestModel(BoringModel):
         def validation_step(self, batch, batch_idx):
             self.log("foo", -batch_idx)

@@ -746,7 +746,7 @@ def test_tested_checkpoint_path_best(tmpdir, checkpoint_callback, fn):
         limit_predict_batches=1,
         enable_progress_bar=False,
         default_root_dir=tmpdir,
-        checkpoint_callback=checkpoint_callback,
+        enable_checkpointing=enable_checkpointing,
     )
     trainer.fit(model)

@@ -754,7 +754,7 @@ def test_tested_checkpoint_path_best(tmpdir, checkpoint_callback, fn):
     path_attr = f"{fn}{'d' if fn == 'validate' else 'ed'}_ckpt_path"
     assert getattr(trainer, path_attr) is None

-    if checkpoint_callback:
+    if enable_checkpointing:
         trainer_fn(ckpt_path="best")
         assert getattr(trainer, path_attr) == trainer.checkpoint_callback.best_model_path

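The three hunks above also show that renaming a parametrized argument is a multi-place edit: the name string in the pytest.mark.parametrize decorator, the parameter in the test signature, and every use inside the body must agree, or pytest reports an error at collection time. A minimal sketch of the pattern:

import pytest


@pytest.mark.parametrize("enable_checkpointing", (False, True))
def test_example(enable_checkpointing):
    # The decorator's name string must match the function parameter exactly;
    # otherwise pytest fails at collection with an unknown/unused argument error.
    assert enable_checkpointing in (False, True)
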
@@ -384,7 +384,7 @@ def test_lightning_cli_config_and_subclass_mode(tmpdir):
         "fit": {
             "model": {"class_path": "tests.helpers.BoringModel"},
             "data": {"class_path": "tests.helpers.BoringDataModule", "init_args": {"data_dir": str(tmpdir)}},
-            "trainer": {"default_root_dir": str(tmpdir), "max_epochs": 1, "weights_summary": None},
+            "trainer": {"default_root_dir": str(tmpdir), "max_epochs": 1, "enable_model_summary": False},
         }
     }
     config_path = tmpdir / "config.yaml"
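
The same rename carries over to configs consumed by LightningCLI: the trainer section switches from the deprecated weights_summary key to enable_model_summary. A hedged sketch of how a test like this could write such a config (the file name and the yaml dump step are illustrative, assuming PyYAML is available):

import yaml

# Illustrative config mirroring the shape used in the hunk above.
config = {
    "fit": {
        "model": {"class_path": "tests.helpers.BoringModel"},
        "trainer": {"max_epochs": 1, "enable_model_summary": False},
    }
}
with open("config.yaml", "w") as f:
    yaml.safe_dump(config, f)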