parent
2b3f443f6b
commit
1d5f06223a
|
@ -249,7 +249,7 @@ def test_dp_output_reduce():
|
||||||
assert reduced['b']['c'] == out['b']['c']
|
assert reduced['b']['c'] == out['b']['c']
|
||||||
|
|
||||||
|
|
||||||
def test_model_checkpoint_options(tmp_path):
|
def test_model_checkpoint_options(tmpdir):
|
||||||
"""Test ModelCheckpoint options."""
|
"""Test ModelCheckpoint options."""
|
||||||
def mock_save_function(filepath):
|
def mock_save_function(filepath):
|
||||||
open(filepath, 'a').close()
|
open(filepath, 'a').close()
|
||||||
|
@ -258,8 +258,8 @@ def test_model_checkpoint_options(tmp_path):
|
||||||
_ = LightningTestModel(hparams)
|
_ = LightningTestModel(hparams)
|
||||||
|
|
||||||
# simulated losses
|
# simulated losses
|
||||||
save_dir = tmp_path / "1"
|
save_dir = os.path.join(tmpdir, '1')
|
||||||
save_dir.mkdir()
|
os.mkdir(save_dir)
|
||||||
losses = [10, 9, 2.8, 5, 2.5]
|
losses = [10, 9, 2.8, 5, 2.5]
|
||||||
|
|
||||||
# -----------------
|
# -----------------
|
||||||
|
@ -286,8 +286,8 @@ def test_model_checkpoint_options(tmp_path):
|
||||||
'epoch=0.ckpt'}:
|
'epoch=0.ckpt'}:
|
||||||
assert fname in file_lists
|
assert fname in file_lists
|
||||||
|
|
||||||
save_dir = tmp_path / "2"
|
save_dir = os.path.join(tmpdir, '2')
|
||||||
save_dir.mkdir()
|
os.mkdir(save_dir)
|
||||||
|
|
||||||
# -----------------
|
# -----------------
|
||||||
# CASE K=0 (none)
|
# CASE K=0 (none)
|
||||||
|
@ -305,8 +305,8 @@ def test_model_checkpoint_options(tmp_path):
|
||||||
|
|
||||||
assert len(file_lists) == 0, "Should save 0 models when save_top_k=0"
|
assert len(file_lists) == 0, "Should save 0 models when save_top_k=0"
|
||||||
|
|
||||||
save_dir = tmp_path / "3"
|
save_dir = os.path.join(tmpdir, '3')
|
||||||
save_dir.mkdir()
|
os.mkdir(save_dir)
|
||||||
|
|
||||||
# -----------------
|
# -----------------
|
||||||
# CASE K=1 (2.5, epoch 4)
|
# CASE K=1 (2.5, epoch 4)
|
||||||
|
@ -325,8 +325,8 @@ def test_model_checkpoint_options(tmp_path):
|
||||||
assert len(file_lists) == 1, "Should save 1 model when save_top_k=1"
|
assert len(file_lists) == 1, "Should save 1 model when save_top_k=1"
|
||||||
assert 'test_prefix_epoch=4.ckpt' in file_lists
|
assert 'test_prefix_epoch=4.ckpt' in file_lists
|
||||||
|
|
||||||
save_dir = tmp_path / "4"
|
save_dir = os.path.join(tmpdir, '4')
|
||||||
save_dir.mkdir()
|
os.mkdir(save_dir)
|
||||||
|
|
||||||
# -----------------
|
# -----------------
|
||||||
# CASE K=2 (2.5 epoch 4, 2.8 epoch 2)
|
# CASE K=2 (2.5 epoch 4, 2.8 epoch 2)
|
||||||
|
@ -351,8 +351,8 @@ def test_model_checkpoint_options(tmp_path):
|
||||||
'other_file.ckpt'}:
|
'other_file.ckpt'}:
|
||||||
assert fname in file_lists
|
assert fname in file_lists
|
||||||
|
|
||||||
save_dir = tmp_path / "5"
|
save_dir = os.path.join(tmpdir, '5')
|
||||||
save_dir.mkdir()
|
os.mkdir(save_dir)
|
||||||
|
|
||||||
# -----------------
|
# -----------------
|
||||||
# CASE K=4 (save all 4 models)
|
# CASE K=4 (save all 4 models)
|
||||||
|
@ -372,8 +372,8 @@ def test_model_checkpoint_options(tmp_path):
|
||||||
|
|
||||||
assert len(file_lists) == 4, 'Should save all 4 models when save_top_k=4 within same epoch'
|
assert len(file_lists) == 4, 'Should save all 4 models when save_top_k=4 within same epoch'
|
||||||
|
|
||||||
save_dir = tmp_path / "6"
|
save_dir = os.path.join(tmpdir, '6')
|
||||||
save_dir.mkdir()
|
os.mkdir(save_dir)
|
||||||
|
|
||||||
# -----------------
|
# -----------------
|
||||||
# CASE K=3 (save the 2nd, 3rd, 4th model)
|
# CASE K=3 (save the 2nd, 3rd, 4th model)
|
||||||
|
|
Loading…
Reference in New Issue