Attach version_ to checkpoint path only if version is int (#1748)

This commit is contained in:
Peter Yu 2020-05-06 12:38:32 -04:00 committed by GitHub
parent 0cb58fbb4c
commit 851866333c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 30 additions and 2 deletions

View File

@ -15,7 +15,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added type hints in `Trainer.fit()` and `Trainer.test()` to reflect that also a list of dataloaders can be passed in ([#1723](https://github.com/PyTorchLightning/pytorch-lightning/pull/1723)).
### Changed
- Reduction when `batch_size < num_gpus` ([#1609](https://github.com/PyTorchLightning/pytorch-lightning/pull/1609))
- Updated LightningTemplateModel to look more like Colab example ([#1577](https://github.com/PyTorchLightning/pytorch-lightning/pull/1577))
@ -38,6 +38,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed bugs that prevent lr finder to be used together with early stopping and validation dataloaders ([#1676](https://github.com/PyTorchLightning/pytorch-lightning/pull/1676))
- Fixed a bug in Trainer that prepended the checkpoint path with `version_` when it shouldn't ([#1748](https://github.com/PyTorchLightning/pytorch-lightning/pull/1748))
## [0.7.5] - 2020-04-27
### Changed

View File

@ -49,10 +49,12 @@ class TrainerCallbackConfigMixin(ABC):
if self.weights_save_path is not None:
save_dir = self.weights_save_path
version = self.logger.version if isinstance(
self.logger.version, str) else f'version_{self.logger.version}'
ckpt_path = os.path.join(
save_dir,
self.logger.name,
f'version_{self.logger.version}',
version,
"checkpoints"
)
else:

View File

@ -3,7 +3,9 @@ import tests.base.utils as tutils
from pytorch_lightning import Callback
from pytorch_lightning import Trainer, LightningModule
from pytorch_lightning.callbacks import EarlyStopping, LearningRateLogger, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from tests.base import EvalModelTemplate
from pathlib import Path
def test_trainer_callback_system(tmpdir):
@ -258,6 +260,28 @@ def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k):
assert trainer.ckpt_path != trainer.default_root_dir
@pytest.mark.parametrize(
'logger_version,expected',
[(None, 'version_0'), (1, 'version_1'), ('awesome', 'awesome')],
)
def test_model_checkpoint_path(tmpdir, logger_version, expected):
"""Test that "version_" prefix is only added when logger's version is an integer"""
tutils.reset_seed()
model = EvalModelTemplate(tutils.get_default_hparams())
logger = TensorBoardLogger(str(tmpdir), version=logger_version)
trainer = Trainer(
default_root_dir=tmpdir,
overfit_pct=0.2,
max_epochs=5,
logger=logger
)
trainer.fit(model)
ckpt_version = Path(trainer.ckpt_path).parent.name
assert ckpt_version == expected
def test_lr_logger_single_lr(tmpdir):
""" Test that learning rates are extracted and logged for single lr scheduler"""
tutils.reset_seed()