Attach version_ to checkpoint path only if version is int (#1748)
This commit is contained in:
parent
0cb58fbb4c
commit
851866333c
|
@ -15,7 +15,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Added type hints in `Trainer.fit()` and `Trainer.test()` to reflect that a list of dataloaders can also be passed in ([#1723](https://github.com/PyTorchLightning/pytorch-lightning/pull/1723)).
|
||||
|
||||
### Changed
|
||||
|
||||
|
||||
- Reduction when `batch_size < num_gpus` ([#1609](https://github.com/PyTorchLightning/pytorch-lightning/pull/1609))
|
||||
|
||||
- Updated LightningTemplateModel to look more like Colab example ([#1577](https://github.com/PyTorchLightning/pytorch-lightning/pull/1577))
|
||||
|
@ -38,6 +38,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
|
||||
- Fixed bugs that prevented the learning rate finder from being used together with early stopping and validation dataloaders ([#1676](https://github.com/PyTorchLightning/pytorch-lightning/pull/1676))
|
||||
|
||||
- Fixed a bug in Trainer that prepended `version_` to the checkpoint path when it shouldn't have ([#1748](https://github.com/PyTorchLightning/pytorch-lightning/pull/1748))
|
||||
|
||||
## [0.7.5] - 2020-04-27
|
||||
|
||||
### Changed
|
||||
|
|
|
@ -49,10 +49,12 @@ class TrainerCallbackConfigMixin(ABC):
|
|||
if self.weights_save_path is not None:
|
||||
save_dir = self.weights_save_path
|
||||
|
||||
version = self.logger.version if isinstance(
|
||||
self.logger.version, str) else f'version_{self.logger.version}'
|
||||
ckpt_path = os.path.join(
|
||||
save_dir,
|
||||
self.logger.name,
|
||||
f'version_{self.logger.version}',
|
||||
version,
|
||||
"checkpoints"
|
||||
)
|
||||
else:
|
||||
|
|
|
@ -3,7 +3,9 @@ import tests.base.utils as tutils
|
|||
from pytorch_lightning import Callback
|
||||
from pytorch_lightning import Trainer, LightningModule
|
||||
from pytorch_lightning.callbacks import EarlyStopping, LearningRateLogger, ModelCheckpoint
|
||||
from pytorch_lightning.loggers import TensorBoardLogger
|
||||
from tests.base import EvalModelTemplate
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def test_trainer_callback_system(tmpdir):
|
||||
|
@ -258,6 +260,28 @@ def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k):
|
|||
assert trainer.ckpt_path != trainer.default_root_dir
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    'logger_version,expected',
    [(None, 'version_0'), (1, 'version_1'), ('awesome', 'awesome')],
)
def test_model_checkpoint_path(tmpdir, logger_version, expected):
    """Test that "version_" prefix is only added when logger's version is an integer"""
    tutils.reset_seed()

    template = EvalModelTemplate(tutils.get_default_hparams())
    tb_logger = TensorBoardLogger(str(tmpdir), version=logger_version)

    # A short run (20% of the data, 5 epochs) is enough to produce a checkpoint.
    fit_options = dict(
        default_root_dir=tmpdir,
        overfit_pct=0.2,
        max_epochs=5,
        logger=tb_logger,
    )
    trainer = Trainer(**fit_options)
    trainer.fit(template)

    # Checkpoints live under <save_dir>/<name>/<version>/checkpoints, so the
    # parent directory of ckpt_path is the version folder under inspection.
    version_dir = Path(trainer.ckpt_path).parent.name
    assert version_dir == expected
|
||||
|
||||
|
||||
def test_lr_logger_single_lr(tmpdir):
|
||||
""" Test that learning rates are extracted and logged for single lr scheduler"""
|
||||
tutils.reset_seed()
|
||||
|
|
Loading…
Reference in New Issue