parent
2b3f443f6b
commit
1d5f06223a
|
@ -249,7 +249,7 @@ def test_dp_output_reduce():
|
|||
assert reduced['b']['c'] == out['b']['c']
|
||||
|
||||
|
||||
def test_model_checkpoint_options(tmp_path):
|
||||
def test_model_checkpoint_options(tmpdir):
|
||||
"""Test ModelCheckpoint options."""
|
||||
def mock_save_function(filepath):
|
||||
open(filepath, 'a').close()
|
||||
|
@ -258,8 +258,8 @@ def test_model_checkpoint_options(tmp_path):
|
|||
_ = LightningTestModel(hparams)
|
||||
|
||||
# simulated losses
|
||||
save_dir = tmp_path / "1"
|
||||
save_dir.mkdir()
|
||||
save_dir = os.path.join(tmpdir, '1')
|
||||
os.mkdir(save_dir)
|
||||
losses = [10, 9, 2.8, 5, 2.5]
|
||||
|
||||
# -----------------
|
||||
|
@ -286,8 +286,8 @@ def test_model_checkpoint_options(tmp_path):
|
|||
'epoch=0.ckpt'}:
|
||||
assert fname in file_lists
|
||||
|
||||
save_dir = tmp_path / "2"
|
||||
save_dir.mkdir()
|
||||
save_dir = os.path.join(tmpdir, '2')
|
||||
os.mkdir(save_dir)
|
||||
|
||||
# -----------------
|
||||
# CASE K=0 (none)
|
||||
|
@ -305,8 +305,8 @@ def test_model_checkpoint_options(tmp_path):
|
|||
|
||||
assert len(file_lists) == 0, "Should save 0 models when save_top_k=0"
|
||||
|
||||
save_dir = tmp_path / "3"
|
||||
save_dir.mkdir()
|
||||
save_dir = os.path.join(tmpdir, '3')
|
||||
os.mkdir(save_dir)
|
||||
|
||||
# -----------------
|
||||
# CASE K=1 (2.5, epoch 4)
|
||||
|
@ -325,8 +325,8 @@ def test_model_checkpoint_options(tmp_path):
|
|||
assert len(file_lists) == 1, "Should save 1 model when save_top_k=1"
|
||||
assert 'test_prefix_epoch=4.ckpt' in file_lists
|
||||
|
||||
save_dir = tmp_path / "4"
|
||||
save_dir.mkdir()
|
||||
save_dir = os.path.join(tmpdir, '4')
|
||||
os.mkdir(save_dir)
|
||||
|
||||
# -----------------
|
||||
# CASE K=2 (2.5 epoch 4, 2.8 epoch 2)
|
||||
|
@ -351,8 +351,8 @@ def test_model_checkpoint_options(tmp_path):
|
|||
'other_file.ckpt'}:
|
||||
assert fname in file_lists
|
||||
|
||||
save_dir = tmp_path / "5"
|
||||
save_dir.mkdir()
|
||||
save_dir = os.path.join(tmpdir, '5')
|
||||
os.mkdir(save_dir)
|
||||
|
||||
# -----------------
|
||||
# CASE K=4 (save all 4 models)
|
||||
|
@ -372,8 +372,8 @@ def test_model_checkpoint_options(tmp_path):
|
|||
|
||||
assert len(file_lists) == 4, 'Should save all 4 models when save_top_k=4 within same epoch'
|
||||
|
||||
save_dir = tmp_path / "6"
|
||||
save_dir.mkdir()
|
||||
save_dir = os.path.join(tmpdir, '6')
|
||||
os.mkdir(save_dir)
|
||||
|
||||
# -----------------
|
||||
# CASE K=3 (save the 2nd, 3rd, 4th model)
|
||||
|
|
Loading…
Reference in New Issue