diff --git a/CHANGELOG.md b/CHANGELOG.md index bda687d2bb..f3190f9d73 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,7 +15,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added type hints in `Trainer.fit()` and `Trainer.test()` to reflect that also a list of dataloaders can be passed in ([#1723](https://github.com/PyTorchLightning/pytorch-lightning/pull/1723)). ### Changed - + - Reduction when `batch_size < num_gpus` ([#1609](https://github.com/PyTorchLightning/pytorch-lightning/pull/1609)) - Updated LightningTemplateModel to look more like Colab example ([#1577](https://github.com/PyTorchLightning/pytorch-lightning/pull/1577)) @@ -38,6 +38,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed bugs that prevent lr finder to be used together with early stopping and validation dataloaders ([#1676](https://github.com/PyTorchLightning/pytorch-lightning/pull/1676)) +- Fixed a bug in Trainer that prepended the checkpoint path with `version_` when it shouldn't ([#1748](https://github.com/PyTorchLightning/pytorch-lightning/pull/1748)) + ## [0.7.5] - 2020-04-27 ### Changed diff --git a/pytorch_lightning/trainer/callback_config.py b/pytorch_lightning/trainer/callback_config.py index 39c6796316..a760b97602 100644 --- a/pytorch_lightning/trainer/callback_config.py +++ b/pytorch_lightning/trainer/callback_config.py @@ -49,10 +49,12 @@ class TrainerCallbackConfigMixin(ABC): if self.weights_save_path is not None: save_dir = self.weights_save_path + version = self.logger.version if isinstance( + self.logger.version, str) else f'version_{self.logger.version}' ckpt_path = os.path.join( save_dir, self.logger.name, - f'version_{self.logger.version}', + version, "checkpoints" ) else: diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py index 8a50cb667c..884fc82e13 100644 --- a/tests/callbacks/test_callbacks.py +++ b/tests/callbacks/test_callbacks.py @@ -3,7 +3,9 @@ import tests.base.utils as tutils from pytorch_lightning import Callback from pytorch_lightning import Trainer, LightningModule from pytorch_lightning.callbacks import EarlyStopping, LearningRateLogger, ModelCheckpoint +from pytorch_lightning.loggers import TensorBoardLogger from tests.base import EvalModelTemplate +from pathlib import Path def test_trainer_callback_system(tmpdir): @@ -258,6 +260,28 @@ def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k): assert trainer.ckpt_path != trainer.default_root_dir +@pytest.mark.parametrize( + 'logger_version,expected', + [(None, 'version_0'), (1, 'version_1'), ('awesome', 'awesome')], +) +def test_model_checkpoint_path(tmpdir, logger_version, expected): + """Test that "version_" prefix is only added when logger's version is an integer""" + tutils.reset_seed() + model = EvalModelTemplate(tutils.get_default_hparams()) + logger = TensorBoardLogger(str(tmpdir), version=logger_version) + + trainer = Trainer( + default_root_dir=tmpdir, + overfit_pct=0.2, + max_epochs=5, + logger=logger + ) + trainer.fit(model) + + ckpt_version = Path(trainer.ckpt_path).parent.name + assert ckpt_version == expected + + def test_lr_logger_single_lr(tmpdir): """ Test that learning rates are extracted and logged for single lr scheduler""" tutils.reset_seed()