Fix test interactions (#18994)
This commit is contained in:
parent
3acea8d157
commit
340961a6ec
|
@ -214,10 +214,10 @@ def test_tensorboard_finalize(monkeypatch, tmp_path):
|
|||
|
||||
|
||||
@mock.patch("lightning.fabric.loggers.tensorboard.log")
|
||||
def test_tensorboard_with_symlink(log, tmp_path):
|
||||
def test_tensorboard_with_symlink(log, tmp_path, monkeypatch):
|
||||
"""Tests a specific failure case when tensorboard logger is used with empty name, symbolic link ``save_dir``, and
|
||||
relative paths."""
|
||||
os.chdir(tmp_path) # need to use relative paths
|
||||
monkeypatch.chdir(tmp_path) # need to use relative paths
|
||||
source = os.path.join(".", "lightning_logs")
|
||||
dest = os.path.join(".", "sym_lightning_logs")
|
||||
|
||||
|
|
|
@ -401,7 +401,7 @@ def test_model_checkpoint_no_extraneous_invocations(tmpdir):
|
|||
assert trainer.state.finished, f"Training failed with {trainer.state}"
|
||||
|
||||
|
||||
def test_model_checkpoint_format_checkpoint_name(tmpdir):
|
||||
def test_model_checkpoint_format_checkpoint_name(tmpdir, monkeypatch):
|
||||
# empty filename:
|
||||
ckpt_name = ModelCheckpoint._format_checkpoint_name("", {"epoch": 3, "step": 2})
|
||||
assert ckpt_name == "epoch=3-step=2"
|
||||
|
@ -422,18 +422,16 @@ def test_model_checkpoint_format_checkpoint_name(tmpdir):
|
|||
assert ckpt_name == "epoch=003-epoch_test=003"
|
||||
|
||||
# prefix
|
||||
char_org = ModelCheckpoint.CHECKPOINT_JOIN_CHAR
|
||||
ModelCheckpoint.CHECKPOINT_JOIN_CHAR = "@"
|
||||
monkeypatch.setattr(ModelCheckpoint, "CHECKPOINT_JOIN_CHAR", "@")
|
||||
ckpt_name = ModelCheckpoint._format_checkpoint_name("{epoch},{acc:.5f}", {"epoch": 3, "acc": 0.03}, prefix="test")
|
||||
assert ckpt_name == "test@epoch=3,acc=0.03000"
|
||||
ModelCheckpoint.CHECKPOINT_JOIN_CHAR = char_org
|
||||
monkeypatch.undo()
|
||||
|
||||
# non-default char for equals sign
|
||||
default_char = ModelCheckpoint.CHECKPOINT_EQUALS_CHAR
|
||||
ModelCheckpoint.CHECKPOINT_EQUALS_CHAR = ":"
|
||||
monkeypatch.setattr(ModelCheckpoint, "CHECKPOINT_EQUALS_CHAR", ":")
|
||||
ckpt_name = ModelCheckpoint._format_checkpoint_name("{epoch:03d}-{acc}", {"epoch": 3, "acc": 0.03})
|
||||
assert ckpt_name == "epoch:003-acc:0.03"
|
||||
ModelCheckpoint.CHECKPOINT_EQUALS_CHAR = default_char
|
||||
monkeypatch.undo()
|
||||
|
||||
# no dirpath set
|
||||
ckpt_name = ModelCheckpoint(monitor="early_stop_on", dirpath=None).format_checkpoint_name({"epoch": 3, "step": 2})
|
||||
|
|
|
@ -313,10 +313,10 @@ def test_tensorboard_save_hparams_to_yaml_once(tmp_path):
|
|||
|
||||
|
||||
@mock.patch("lightning.pytorch.loggers.tensorboard.log")
|
||||
def test_tensorboard_with_symlink(log, tmp_path):
|
||||
def test_tensorboard_with_symlink(log, tmp_path, monkeypatch):
|
||||
"""Tests a specific failure case when tensorboard logger is used with empty name, symbolic link ``save_dir``, and
|
||||
relative paths."""
|
||||
os.chdir(tmp_path) # need to use relative paths
|
||||
monkeypatch.chdir(tmp_path) # need to use relative paths
|
||||
source = os.path.join(".", "lightning_logs")
|
||||
dest = os.path.join(".", "sym_lightning_logs")
|
||||
|
||||
|
|
Loading…
Reference in New Issue