From b00991efd8d6b7d1941d0eb3c1a499f95b4a3eea Mon Sep 17 00:00:00 2001 From: Jan-Henrik Lambrechts <31068156+janhenriklambrechts@users.noreply.github.com> Date: Mon, 7 Dec 2020 01:28:50 +0800 Subject: [PATCH] Added changeable extension variable for model checkpoints (#4977) * Added changeable extension variable for model checkpoints * Removed whitespace * Removed the last bit of whitespace * Wrote tests for FILE_EXTENSION * Fixed formatting issues * More formatting issues * Simplify test by just using defaults * Formatting to PEP8 * Added dummy class that inherits ModelCheckpoint; run only one batch instead of epoch for integration test * Fixed too much whitespace formatting * some changes Co-authored-by: rohitgr7 --- .../callbacks/model_checkpoint.py | 7 +++--- tests/checkpointing/test_model_checkpoint.py | 23 +++++++++++++++++++ 2 files changed, 27 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 79feba5a41..eb669736ad 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -90,7 +90,7 @@ class ModelCheckpoint(Callback): Example:: # custom path - # saves a file like: my/path/epoch=0.ckpt + # saves a file like: my/path/epoch=0-step=10.ckpt >>> checkpoint_callback = ModelCheckpoint(dirpath='my/path/') By default, dirpath is ``None`` and will be set at runtime to the location @@ -140,6 +140,7 @@ class ModelCheckpoint(Callback): CHECKPOINT_JOIN_CHAR = "-" CHECKPOINT_NAME_LAST = "last" + FILE_EXTENSION = ".ckpt" def __init__( self, @@ -442,7 +443,7 @@ class ModelCheckpoint(Callback): ) if ver is not None: filename = self.CHECKPOINT_JOIN_CHAR.join((filename, f"v{ver}")) - ckpt_name = f"{filename}.ckpt" + ckpt_name = f"{filename}{self.FILE_EXTENSION}" return os.path.join(self.dirpath, ckpt_name) if self.dirpath else ckpt_name def __resolve_ckpt_dir(self, trainer, pl_module): @@ -545,7 +546,7 @@ class ModelCheckpoint(Callback): ckpt_name_metrics, prefix=self.prefix ) - last_filepath = os.path.join(self.dirpath, f"{last_filepath}.ckpt") + last_filepath = os.path.join(self.dirpath, f"{last_filepath}{self.FILE_EXTENSION}") self._save_model(last_filepath, trainer, pl_module) if ( diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 33bc19a894..6d1d3edea5 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -261,6 +261,29 @@ def test_model_checkpoint_format_checkpoint_name(tmpdir): assert ckpt_name == filepath / 'test-epoch=3-step=2.ckpt' +class ModelCheckpointExtensionTest(ModelCheckpoint): + FILE_EXTENSION = '.tpkc' + + +def test_model_checkpoint_file_extension(tmpdir): + """ + Test ModelCheckpoint with different file extension. + """ + + model = LogInTwoMethods() + model_checkpoint = ModelCheckpointExtensionTest(monitor='early_stop_on', dirpath=tmpdir, save_top_k=1, save_last=True) + trainer = Trainer( + default_root_dir=tmpdir, + callbacks=[model_checkpoint], + max_steps=1, + logger=False, + ) + trainer.fit(model) + + expected = ['epoch=0-step=0.tpkc', 'last.tpkc'] + assert set(expected) == set(os.listdir(tmpdir)) + + def test_model_checkpoint_save_last(tmpdir): """Tests that save_last produces only one last checkpoint.""" seed_everything()