Added changeable extension variable for model checkpoints (#4977)

* Added changeable extension variable for model checkpoints

* Removed whitespace

* Removed the last bit of whitespace

* Wrote tests for FILE_EXTENSION

* Fixed formatting issues

* More formatting issues

* Simplify test by just using defaults

* Formatting to PEP8

* Added dummy class that inherits ModelCheckpoint; run only one batch instead of epoch for integration test

* Fixed too much whitespace formatting

* some changes

Co-authored-by: rohitgr7 <rohitgr1998@gmail.com>
This commit is contained in:
Jan-Henrik Lambrechts 2020-12-07 01:28:50 +08:00 committed by GitHub
parent 2e838e6dd8
commit b00991efd8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 27 additions and 3 deletions

View File

@ -90,7 +90,7 @@ class ModelCheckpoint(Callback):
Example::
# custom path
# saves a file like: my/path/epoch=0.ckpt
# saves a file like: my/path/epoch=0-step=10.ckpt
>>> checkpoint_callback = ModelCheckpoint(dirpath='my/path/')
By default, dirpath is ``None`` and will be set at runtime to the location
@ -140,6 +140,7 @@ class ModelCheckpoint(Callback):
CHECKPOINT_JOIN_CHAR = "-"
CHECKPOINT_NAME_LAST = "last"
FILE_EXTENSION = ".ckpt"
def __init__(
self,
@ -442,7 +443,7 @@ class ModelCheckpoint(Callback):
)
if ver is not None:
filename = self.CHECKPOINT_JOIN_CHAR.join((filename, f"v{ver}"))
ckpt_name = f"{filename}.ckpt"
ckpt_name = f"{filename}{self.FILE_EXTENSION}"
return os.path.join(self.dirpath, ckpt_name) if self.dirpath else ckpt_name
def __resolve_ckpt_dir(self, trainer, pl_module):
@ -545,7 +546,7 @@ class ModelCheckpoint(Callback):
ckpt_name_metrics,
prefix=self.prefix
)
last_filepath = os.path.join(self.dirpath, f"{last_filepath}.ckpt")
last_filepath = os.path.join(self.dirpath, f"{last_filepath}{self.FILE_EXTENSION}")
self._save_model(last_filepath, trainer, pl_module)
if (

View File

@ -261,6 +261,29 @@ def test_model_checkpoint_format_checkpoint_name(tmpdir):
assert ckpt_name == filepath / 'test-epoch=3-step=2.ckpt'
class ModelCheckpointExtensionTest(ModelCheckpoint):
FILE_EXTENSION = '.tpkc'
def test_model_checkpoint_file_extension(tmpdir):
"""
Test ModelCheckpoint with different file extension.
"""
model = LogInTwoMethods()
model_checkpoint = ModelCheckpointExtensionTest(monitor='early_stop_on', dirpath=tmpdir, save_top_k=1, save_last=True)
trainer = Trainer(
default_root_dir=tmpdir,
callbacks=[model_checkpoint],
max_steps=1,
logger=False,
)
trainer.fit(model)
expected = ['epoch=0-step=0.tpkc', 'last.tpkc']
assert set(expected) == set(os.listdir(tmpdir))
def test_model_checkpoint_save_last(tmpdir):
"""Tests that save_last produces only one last checkpoint."""
seed_everything()