Added changeable extension variable for model checkpoints (#4977)
* Added changeable extension variable for model checkpoints * Removed whitespace * Removed the last bit of whitespace * Wrote tests for FILE_EXTENSION * Fixed formatting issues * More formatting issues * Simplify test by just using defaults * Formatting to PEP8 * Added dummy class that inherits ModelCheckpoint; run only one batch instead of epoch for integration test * Fixed too much whitespace formatting * some changes Co-authored-by: rohitgr7 <rohitgr1998@gmail.com>
This commit is contained in:
parent
2e838e6dd8
commit
b00991efd8
|
@ -90,7 +90,7 @@ class ModelCheckpoint(Callback):
|
|||
Example::
|
||||
|
||||
# custom path
|
||||
# saves a file like: my/path/epoch=0.ckpt
|
||||
# saves a file like: my/path/epoch=0-step=10.ckpt
|
||||
>>> checkpoint_callback = ModelCheckpoint(dirpath='my/path/')
|
||||
|
||||
By default, dirpath is ``None`` and will be set at runtime to the location
|
||||
|
@ -140,6 +140,7 @@ class ModelCheckpoint(Callback):
|
|||
|
||||
CHECKPOINT_JOIN_CHAR = "-"
|
||||
CHECKPOINT_NAME_LAST = "last"
|
||||
FILE_EXTENSION = ".ckpt"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -442,7 +443,7 @@ class ModelCheckpoint(Callback):
|
|||
)
|
||||
if ver is not None:
|
||||
filename = self.CHECKPOINT_JOIN_CHAR.join((filename, f"v{ver}"))
|
||||
ckpt_name = f"{filename}.ckpt"
|
||||
ckpt_name = f"{filename}{self.FILE_EXTENSION}"
|
||||
return os.path.join(self.dirpath, ckpt_name) if self.dirpath else ckpt_name
|
||||
|
||||
def __resolve_ckpt_dir(self, trainer, pl_module):
|
||||
|
@ -545,7 +546,7 @@ class ModelCheckpoint(Callback):
|
|||
ckpt_name_metrics,
|
||||
prefix=self.prefix
|
||||
)
|
||||
last_filepath = os.path.join(self.dirpath, f"{last_filepath}.ckpt")
|
||||
last_filepath = os.path.join(self.dirpath, f"{last_filepath}{self.FILE_EXTENSION}")
|
||||
|
||||
self._save_model(last_filepath, trainer, pl_module)
|
||||
if (
|
||||
|
|
|
@ -261,6 +261,29 @@ def test_model_checkpoint_format_checkpoint_name(tmpdir):
|
|||
assert ckpt_name == filepath / 'test-epoch=3-step=2.ckpt'
|
||||
|
||||
|
||||
class ModelCheckpointExtensionTest(ModelCheckpoint):
|
||||
FILE_EXTENSION = '.tpkc'
|
||||
|
||||
|
||||
def test_model_checkpoint_file_extension(tmpdir):
|
||||
"""
|
||||
Test ModelCheckpoint with different file extension.
|
||||
"""
|
||||
|
||||
model = LogInTwoMethods()
|
||||
model_checkpoint = ModelCheckpointExtensionTest(monitor='early_stop_on', dirpath=tmpdir, save_top_k=1, save_last=True)
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
callbacks=[model_checkpoint],
|
||||
max_steps=1,
|
||||
logger=False,
|
||||
)
|
||||
trainer.fit(model)
|
||||
|
||||
expected = ['epoch=0-step=0.tpkc', 'last.tpkc']
|
||||
assert set(expected) == set(os.listdir(tmpdir))
|
||||
|
||||
|
||||
def test_model_checkpoint_save_last(tmpdir):
|
||||
"""Tests that save_last produces only one last checkpoint."""
|
||||
seed_everything()
|
||||
|
|
Loading…
Reference in New Issue