Add step index in checkpoint name (#3807)
* true final value of global step * ch check * tests * save each validation interval * wip * add test * add test * wip * fix tests, revert old edits, fix merge conflicts, update doctests * test + bugfix * sort files * format test * suggestion by ananth * added changelog * naming * docs * example * suggestion Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com> * fix test * pep * pep Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com> Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
This commit is contained in:
parent
f40d08679d
commit
ef03c39ab7
|
@ -17,6 +17,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
|
||||
- Added multiclass AUROC metric ([#4236](https://github.com/PyTorchLightning/pytorch-lightning/pull/4236))
|
||||
|
||||
- Added global step indexing to the checkpoint name for a better sub-epoch checkpointing experience ([#3807](https://github.com/PyTorchLightning/pytorch-lightning/pull/3807))
|
||||
|
||||
### Changed
|
||||
|
||||
- W&B log in sync with Trainer step ([#4405](https://github.com/PyTorchLightning/pytorch-lightning/pull/4405))
|
||||
|
|
|
@ -101,7 +101,7 @@ class ModelCheckpoint(Callback):
|
|||
... filename='{epoch}-{val_loss:.2f}-{other_metric:.2f}'
|
||||
... )
|
||||
|
||||
By default, filename is ``None`` and will be set to ``'{epoch}'``.
|
||||
By default, filename is ``None`` and will be set to ``'{epoch}-{step}'``.
|
||||
|
||||
|
||||
Example::
|
||||
|
@ -222,16 +222,16 @@ class ModelCheckpoint(Callback):
|
|||
monitor_candidates = self._monitor_candidates(trainer)
|
||||
|
||||
# e.g.: path/val_loss=0.5.ckpt
|
||||
filepath = self._get_metric_interpolated_filepath_name(epoch, monitor_candidates)
|
||||
filepath = self._get_metric_interpolated_filepath_name(monitor_candidates, epoch, global_step)
|
||||
|
||||
# callback supports multiple simultaneous modes
|
||||
# here we call each mode sequentially
|
||||
# Mode 1: save all checkpoints OR only the top k
|
||||
if self.save_top_k:
|
||||
self._save_top_k_checkpoints(monitor_candidates, trainer, pl_module, epoch, filepath)
|
||||
self._save_top_k_checkpoints(monitor_candidates, trainer, pl_module, filepath)
|
||||
|
||||
# Mode 2: save the last checkpoint
|
||||
self._save_last_checkpoint(trainer, pl_module, epoch, monitor_candidates, filepath)
|
||||
self._save_last_checkpoint(trainer, pl_module, monitor_candidates, filepath)
|
||||
|
||||
def __validate_init_configuration(self):
|
||||
if self.save_top_k is not None and self.save_top_k < -1:
|
||||
|
@ -360,16 +360,17 @@ class ModelCheckpoint(Callback):
|
|||
cls,
|
||||
filename: Optional[str],
|
||||
epoch: int,
|
||||
step: int,
|
||||
metrics: Dict[str, Any],
|
||||
prefix: str = "",
|
||||
) -> str:
|
||||
if not filename:
|
||||
# filename is not set, use default name
|
||||
filename = "{epoch}"
|
||||
filename = "{epoch}-{step}"
|
||||
# check and parse user passed keys in the string
|
||||
groups = re.findall(r"(\{.*?)[:\}]", filename)
|
||||
if len(groups) >= 0:
|
||||
metrics["epoch"] = epoch
|
||||
metrics.update({"epoch": epoch, 'step': step})
|
||||
for group in groups:
|
||||
name = group[1:]
|
||||
filename = filename.replace(group, name + "={" + name)
|
||||
|
@ -379,7 +380,7 @@ class ModelCheckpoint(Callback):
|
|||
return cls.CHECKPOINT_JOIN_CHAR.join([txt for txt in (prefix, filename) if txt])
|
||||
|
||||
def format_checkpoint_name(
|
||||
self, epoch: int, metrics: Dict[str, Any], ver: Optional[int] = None
|
||||
self, epoch: int, step: int, metrics: Dict[str, Any], ver: Optional[int] = None
|
||||
) -> str:
|
||||
"""Generate a filename according to the defined template.
|
||||
|
||||
|
@ -387,24 +388,24 @@ class ModelCheckpoint(Callback):
|
|||
|
||||
>>> tmpdir = os.path.dirname(__file__)
|
||||
>>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}')
|
||||
>>> os.path.basename(ckpt.format_checkpoint_name(0, {}))
|
||||
>>> os.path.basename(ckpt.format_checkpoint_name(0, 1, metrics={}))
|
||||
'epoch=0.ckpt'
|
||||
>>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch:03d}')
|
||||
>>> os.path.basename(ckpt.format_checkpoint_name(5, {}))
|
||||
>>> os.path.basename(ckpt.format_checkpoint_name(5, 2, metrics={}))
|
||||
'epoch=005.ckpt'
|
||||
>>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}-{val_loss:.2f}')
|
||||
>>> os.path.basename(ckpt.format_checkpoint_name(2, dict(val_loss=0.123456)))
|
||||
>>> os.path.basename(ckpt.format_checkpoint_name(2, 3, metrics=dict(val_loss=0.123456)))
|
||||
'epoch=2-val_loss=0.12.ckpt'
|
||||
>>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{missing:d}')
|
||||
>>> os.path.basename(ckpt.format_checkpoint_name(0, {}))
|
||||
>>> os.path.basename(ckpt.format_checkpoint_name(0, 4, metrics={}))
|
||||
'missing=0.ckpt'
|
||||
>>> ckpt = ModelCheckpoint(filename='{epoch}')
|
||||
>>> os.path.basename(ckpt.format_checkpoint_name(0, {}))
|
||||
'epoch=0.ckpt'
|
||||
>>> ckpt = ModelCheckpoint(filename='{step}')
|
||||
>>> os.path.basename(ckpt.format_checkpoint_name(0, 0, {}))
|
||||
'step=0.ckpt'
|
||||
|
||||
"""
|
||||
filename = self._format_checkpoint_name(
|
||||
self.filename, epoch, metrics, prefix=self.prefix
|
||||
self.filename, epoch, step, metrics, prefix=self.prefix
|
||||
)
|
||||
if ver is not None:
|
||||
filename = self.CHECKPOINT_JOIN_CHAR.join((filename, f"v{ver}"))
|
||||
|
@ -479,13 +480,11 @@ class ModelCheckpoint(Callback):
|
|||
)
|
||||
raise MisconfigurationException(m)
|
||||
|
||||
def _get_metric_interpolated_filepath_name(self, epoch, ckpt_name_metrics):
|
||||
filepath = self.format_checkpoint_name(epoch, ckpt_name_metrics)
|
||||
def _get_metric_interpolated_filepath_name(self, ckpt_name_metrics: Dict[str, Any], epoch: int, step: int):
|
||||
filepath = self.format_checkpoint_name(epoch, step, ckpt_name_metrics)
|
||||
version_cnt = 0
|
||||
while self._fs.exists(filepath):
|
||||
filepath = self.format_checkpoint_name(
|
||||
epoch, ckpt_name_metrics, ver=version_cnt
|
||||
)
|
||||
filepath = self.format_checkpoint_name(epoch, step, ckpt_name_metrics, ver=version_cnt)
|
||||
# a checkpoint file with this name already exists, so retry with the next version suffix
|
||||
version_cnt += 1
|
||||
return filepath
|
||||
|
@ -494,9 +493,10 @@ class ModelCheckpoint(Callback):
|
|||
ckpt_name_metrics = deepcopy(trainer.logger_connector.logged_metrics)
|
||||
ckpt_name_metrics.update(trainer.logger_connector.callback_metrics)
|
||||
ckpt_name_metrics.update(trainer.logger_connector.progress_bar_metrics)
|
||||
ckpt_name_metrics.update({"step": trainer.global_step, "epoch": trainer.current_epoch})
|
||||
return ckpt_name_metrics
|
||||
|
||||
def _save_last_checkpoint(self, trainer, pl_module, epoch, ckpt_name_metrics, filepath):
|
||||
def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics, filepath):
|
||||
should_save_last = self.monitor is None or self.save_last
|
||||
if not should_save_last:
|
||||
return
|
||||
|
@ -506,7 +506,11 @@ class ModelCheckpoint(Callback):
|
|||
# when user ALSO asked for the 'last.ckpt' change the name
|
||||
if self.save_last:
|
||||
last_filepath = self._format_checkpoint_name(
|
||||
self.CHECKPOINT_NAME_LAST, epoch, ckpt_name_metrics, prefix=self.prefix
|
||||
self.CHECKPOINT_NAME_LAST,
|
||||
trainer.current_epoch,
|
||||
trainer.global_step,
|
||||
ckpt_name_metrics,
|
||||
prefix=self.prefix
|
||||
)
|
||||
last_filepath = os.path.join(self.dirpath, f"{last_filepath}.ckpt")
|
||||
|
||||
|
@ -523,17 +527,19 @@ class ModelCheckpoint(Callback):
|
|||
if self.monitor is None:
|
||||
self.best_model_path = self.last_model_path
|
||||
|
||||
def _save_top_k_checkpoints(self, metrics, trainer, pl_module, epoch, filepath):
|
||||
def _save_top_k_checkpoints(self, metrics, trainer, pl_module, filepath):
|
||||
current = metrics.get(self.monitor)
|
||||
epoch = metrics.get("epoch")
|
||||
step = metrics.get("step")
|
||||
|
||||
if not isinstance(current, torch.Tensor) and current is not None:
|
||||
current = torch.tensor(current, device=pl_module.device)
|
||||
|
||||
if self.check_monitor_top_k(current):
|
||||
self._update_best_and_save(filepath, current, epoch, trainer, pl_module)
|
||||
self._update_best_and_save(filepath, current, epoch, step, trainer, pl_module)
|
||||
elif self.verbose:
|
||||
rank_zero_info(
|
||||
f"Epoch {epoch:d}: {self.monitor} was not in top {self.save_top_k}"
|
||||
f"Epoch {epoch:d}, step {step:d}: {self.monitor} was not in top {self.save_top_k}"
|
||||
)
|
||||
|
||||
def _is_valid_monitor_key(self, metrics):
|
||||
|
@ -544,11 +550,11 @@ class ModelCheckpoint(Callback):
|
|||
filepath: str,
|
||||
current: torch.Tensor,
|
||||
epoch: int,
|
||||
step: int,
|
||||
trainer,
|
||||
pl_module,
|
||||
):
|
||||
|
||||
k = epoch + 1 if self.save_top_k == -1 else self.save_top_k
|
||||
k = len(self.best_k_models) + 1 if self.save_top_k == -1 else self.save_top_k
|
||||
|
||||
del_list = []
|
||||
if len(self.best_k_models) == k and k > 0:
|
||||
|
@ -575,9 +581,8 @@ class ModelCheckpoint(Callback):
|
|||
|
||||
if self.verbose:
|
||||
rank_zero_info(
|
||||
f"Epoch {epoch:d}: {self.monitor} reached"
|
||||
f" {current:0.5f} (best {self.best_model_score:0.5f}),"
|
||||
f" saving model to {filepath} as top {k}"
|
||||
f"Epoch {epoch:d}, global step {step:d}: {self.monitor} reached {current:0.5f}"
|
||||
f' (best {self.best_model_score:0.5f}), saving model to "{filepath}" as top {k}'
|
||||
)
|
||||
self._save_model(filepath, trainer, pl_module)
|
||||
|
||||
|
|
|
@ -250,9 +250,10 @@ class EvaluationLoop(object):
|
|||
# deprecation warning
|
||||
if eval_results is not None and user_reduced:
|
||||
step = 'testing_epoch_end' if self.testing else 'validation_epoch_end'
|
||||
m = f'The {step} should not return anything as of 9.1.' \
|
||||
f'to log, use self.log(...) or self.write(...) directly in the LightningModule'
|
||||
self.warning_cache.warn(m)
|
||||
self.warning_cache.warn(
|
||||
f'The {step} should not return anything as of 9.1.'
|
||||
' To log, use self.log(...) or self.write(...) directly in the LightningModule'
|
||||
)
|
||||
|
||||
if using_eval_result and not user_reduced:
|
||||
eval_results = self.__auto_reduce_result_objs(outputs)
|
||||
|
|
|
@ -100,7 +100,7 @@ def test_model_checkpoint_to_yaml(tmpdir, save_top_k):
|
|||
path_yaml = os.path.join(tmpdir, 'best_k_models.yaml')
|
||||
checkpoint.to_yaml(path_yaml)
|
||||
d = yaml.full_load(open(path_yaml, 'r'))
|
||||
best_k = {k: v.item() for k, v in checkpoint.best_k_models.items()}
|
||||
best_k = {k: v for k, v in checkpoint.best_k_models.items()}
|
||||
assert d == best_k
|
||||
|
||||
|
||||
|
@ -185,67 +185,72 @@ def test_model_checkpoint_no_extraneous_invocations(tmpdir):
|
|||
|
||||
def test_model_checkpoint_format_checkpoint_name(tmpdir):
|
||||
# empty filename:
|
||||
ckpt_name = ModelCheckpoint._format_checkpoint_name('', 3, {})
|
||||
assert ckpt_name == 'epoch=3'
|
||||
ckpt_name = ModelCheckpoint._format_checkpoint_name('', 3, 2, {})
|
||||
assert ckpt_name == 'epoch=3-step=2'
|
||||
|
||||
ckpt_name = ModelCheckpoint._format_checkpoint_name(None, 3, {}, prefix='test')
|
||||
assert ckpt_name == 'test-epoch=3'
|
||||
ckpt_name = ModelCheckpoint._format_checkpoint_name(None, 3, 2, {}, prefix='test')
|
||||
assert ckpt_name == 'test-epoch=3-step=2'
|
||||
|
||||
# no groups case:
|
||||
ckpt_name = ModelCheckpoint._format_checkpoint_name('ckpt', 3, {}, prefix='test')
|
||||
ckpt_name = ModelCheckpoint._format_checkpoint_name('ckpt', 3, 2, {}, prefix='test')
|
||||
assert ckpt_name == 'test-ckpt'
|
||||
|
||||
# no prefix
|
||||
ckpt_name = ModelCheckpoint._format_checkpoint_name('{epoch:03d}-{acc}', 3, {'acc': 0.03})
|
||||
ckpt_name = ModelCheckpoint._format_checkpoint_name('{epoch:03d}-{acc}', 3, 2, {'acc': 0.03})
|
||||
assert ckpt_name == 'epoch=003-acc=0.03'
|
||||
|
||||
# prefix
|
||||
char_org = ModelCheckpoint.CHECKPOINT_JOIN_CHAR
|
||||
ModelCheckpoint.CHECKPOINT_JOIN_CHAR = '@'
|
||||
ckpt_name = ModelCheckpoint._format_checkpoint_name('{epoch},{acc:.5f}', 3, {'acc': 0.03}, prefix='test')
|
||||
ckpt_name = ModelCheckpoint._format_checkpoint_name('{epoch},{acc:.5f}', 3, 2, {'acc': 0.03}, prefix='test')
|
||||
assert ckpt_name == 'test@epoch=3,acc=0.03000'
|
||||
ModelCheckpoint.CHECKPOINT_JOIN_CHAR = char_org
|
||||
|
||||
# no dirpath set
|
||||
ckpt_name = ModelCheckpoint(monitor='early_stop_on', dirpath=None).format_checkpoint_name(3, {})
|
||||
assert ckpt_name == 'epoch=3.ckpt'
|
||||
ckpt_name = ModelCheckpoint(monitor='early_stop_on', dirpath='').format_checkpoint_name(5, {})
|
||||
assert ckpt_name == 'epoch=5.ckpt'
|
||||
ckpt_name = ModelCheckpoint(monitor='early_stop_on', dirpath=None).format_checkpoint_name(3, 2, {})
|
||||
assert ckpt_name == 'epoch=3-step=2.ckpt'
|
||||
ckpt_name = ModelCheckpoint(monitor='early_stop_on', dirpath='').format_checkpoint_name(5, 4, {})
|
||||
assert ckpt_name == 'epoch=5-step=4.ckpt'
|
||||
|
||||
# CWD
|
||||
ckpt_name = ModelCheckpoint(monitor='early_stop_on', dirpath='.').format_checkpoint_name(3, {})
|
||||
assert ckpt_name == str(Path('.').resolve() / 'epoch=3.ckpt')
|
||||
ckpt_name = ModelCheckpoint(monitor='early_stop_on', dirpath='.').format_checkpoint_name(3, 4, {})
|
||||
assert ckpt_name == str(Path('.').resolve() / 'epoch=3-step=4.ckpt')
|
||||
|
||||
# with ver
|
||||
ckpt_name = ModelCheckpoint(
|
||||
monitor='early_stop_on', dirpath=tmpdir, filename='name', prefix='test'
|
||||
).format_checkpoint_name(3, {}, ver=3)
|
||||
).format_checkpoint_name(3, 2, {}, ver=3)
|
||||
assert ckpt_name == tmpdir / 'test-name-v3.ckpt'
|
||||
|
||||
# using slashes
|
||||
ckpt_name = ModelCheckpoint(
|
||||
monitor='early_stop_on', dirpath=None, filename='{epoch}_{val/loss:.5f}'
|
||||
).format_checkpoint_name(4, {'val/loss': 0.03})
|
||||
).format_checkpoint_name(4, 3, {'val/loss': 0.03})
|
||||
assert ckpt_name == 'epoch=4_val/loss=0.03000.ckpt'
|
||||
|
||||
# TODO: Checks with filepath. To be removed in v1.2
|
||||
# CWD
|
||||
ckpt_name = ModelCheckpoint(monitor='early_stop_on', filepath='.').format_checkpoint_name(3, {})
|
||||
assert ckpt_name == str(Path('.').resolve() / 'epoch=3.ckpt')
|
||||
ckpt_name = ModelCheckpoint(monitor='early_stop_on', filepath='.').format_checkpoint_name(3, 2, {})
|
||||
assert ckpt_name == str(Path('.').resolve() / 'epoch=3-step=2.ckpt')
|
||||
|
||||
# dir does not exist so it is used as filename
|
||||
filepath = tmpdir / 'dir'
|
||||
ckpt_name = ModelCheckpoint(monitor='early_stop_on', filepath=filepath, prefix='test').format_checkpoint_name(3, {})
|
||||
ckpt_name = ModelCheckpoint(
|
||||
monitor='early_stop_on', filepath=filepath, prefix='test'
|
||||
).format_checkpoint_name(3, 2, {})
|
||||
assert ckpt_name == tmpdir / 'test-dir.ckpt'
|
||||
|
||||
# now, dir exists
|
||||
os.mkdir(filepath)
|
||||
ckpt_name = ModelCheckpoint(monitor='early_stop_on', filepath=filepath, prefix='test').format_checkpoint_name(3, {})
|
||||
assert ckpt_name == filepath / 'test-epoch=3.ckpt'
|
||||
ckpt_name = ModelCheckpoint(
|
||||
monitor='early_stop_on', filepath=filepath, prefix='test'
|
||||
).format_checkpoint_name(3, 2, {})
|
||||
assert ckpt_name == filepath / 'test-epoch=3-step=2.ckpt'
|
||||
|
||||
|
||||
def test_model_checkpoint_save_last(tmpdir):
|
||||
"""Tests that save_last produces only one last checkpoint."""
|
||||
seed_everything()
|
||||
model = EvalModelTemplate()
|
||||
epochs = 3
|
||||
ModelCheckpoint.CHECKPOINT_NAME_LAST = 'last-{epoch}'
|
||||
|
@ -257,10 +262,15 @@ def test_model_checkpoint_save_last(tmpdir):
|
|||
logger=False,
|
||||
)
|
||||
trainer.fit(model)
|
||||
last_filename = model_checkpoint._format_checkpoint_name(ModelCheckpoint.CHECKPOINT_NAME_LAST, epochs - 1, {})
|
||||
last_filename = model_checkpoint._format_checkpoint_name(
|
||||
ModelCheckpoint.CHECKPOINT_NAME_LAST, trainer.current_epoch, trainer.global_step, {}
|
||||
)
|
||||
last_filename = last_filename + '.ckpt'
|
||||
assert str(tmpdir / last_filename) == model_checkpoint.last_model_path
|
||||
assert set(os.listdir(tmpdir)) == set([f'epoch={i}.ckpt' for i in range(epochs)] + [last_filename])
|
||||
assert set(os.listdir(tmpdir)) == set(
|
||||
[f"epoch={i}-step={j}.ckpt" for i, j in zip(range(epochs), [9, 19, 29])] + [last_filename]
|
||||
)
|
||||
|
||||
ModelCheckpoint.CHECKPOINT_NAME_LAST = 'last'
|
||||
|
||||
|
||||
|
@ -295,6 +305,7 @@ def test_none_monitor_save_last(tmpdir):
|
|||
|
||||
def test_model_checkpoint_none_monitor(tmpdir):
|
||||
""" Test that it is possible to save all checkpoints when monitor=None. """
|
||||
seed_everything()
|
||||
model = EvalModelTemplate()
|
||||
model.validation_step = model.validation_step_no_monitor
|
||||
model.validation_epoch_end = model.validation_epoch_end_no_monitor
|
||||
|
@ -311,13 +322,13 @@ def test_model_checkpoint_none_monitor(tmpdir):
|
|||
|
||||
# these should not be set if monitor is None
|
||||
assert checkpoint_callback.monitor is None
|
||||
assert checkpoint_callback.best_model_path == checkpoint_callback.last_model_path == tmpdir / 'epoch=1.ckpt'
|
||||
assert checkpoint_callback.best_model_path == checkpoint_callback.last_model_path == tmpdir / 'epoch=1-step=19.ckpt'
|
||||
assert checkpoint_callback.best_model_score == 0
|
||||
assert checkpoint_callback.best_k_models == {}
|
||||
assert checkpoint_callback.kth_best_model_path == ''
|
||||
|
||||
# check that the correct ckpts were created
|
||||
expected = [f'epoch={e}.ckpt' for e in range(epochs)]
|
||||
expected = [f'epoch={i}-step={j}.ckpt' for i, j in zip(range(epochs), [9, 19])]
|
||||
assert set(os.listdir(tmpdir)) == set(expected)
|
||||
|
||||
|
||||
|
@ -325,13 +336,14 @@ def test_model_checkpoint_none_monitor(tmpdir):
|
|||
def test_model_checkpoint_period(tmpdir, period):
|
||||
model = EvalModelTemplate()
|
||||
epochs = 5
|
||||
checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, save_top_k=-1, period=period)
|
||||
checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}', save_top_k=-1, period=period)
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
checkpoint_callback=checkpoint_callback,
|
||||
max_epochs=epochs,
|
||||
limit_train_batches=0.1,
|
||||
limit_val_batches=0.1,
|
||||
val_check_interval=1.0,
|
||||
logger=False,
|
||||
)
|
||||
trainer.fit(model)
|
||||
|
@ -372,12 +384,19 @@ def test_model_checkpoint_topk_all(tmpdir):
|
|||
return {'epoch': self.current_epoch}
|
||||
|
||||
model = CustomModel()
|
||||
checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, monitor="epoch", mode='max', save_top_k=-1)
|
||||
checkpoint_callback = ModelCheckpoint(
|
||||
dirpath=tmpdir,
|
||||
filename="{epoch}",
|
||||
monitor="epoch",
|
||||
mode='max',
|
||||
save_top_k=-1,
|
||||
)
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
checkpoint_callback=checkpoint_callback,
|
||||
max_epochs=epochs,
|
||||
logger=False,
|
||||
val_check_interval=1.0,
|
||||
)
|
||||
trainer.fit(model)
|
||||
|
||||
|
@ -439,7 +458,7 @@ def test_default_checkpoint_behavior(tmpdir):
|
|||
# make sure the checkpoint we saved has the metric in the name
|
||||
ckpts = os.listdir(os.path.join(tmpdir, 'lightning_logs', 'version_0', 'checkpoints'))
|
||||
assert len(ckpts) == 1
|
||||
assert ckpts[0] == 'epoch=2.ckpt'
|
||||
assert ckpts[0] == 'epoch=2-step=14.ckpt'
|
||||
|
||||
|
||||
def test_ckpt_metric_names_results(tmpdir):
|
||||
|
@ -497,7 +516,7 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
|
|||
model = EvalModelTemplate()
|
||||
num_epochs = 3
|
||||
model_checkpoint = ModelCheckpoint(
|
||||
monitor='early_stop_on', dirpath=tmpdir, save_top_k=num_epochs, save_last=True
|
||||
monitor='early_stop_on', dirpath=tmpdir, filename="{epoch}", save_top_k=num_epochs, save_last=True
|
||||
)
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
|
@ -509,6 +528,7 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
|
|||
path_last_epoch = str(tmpdir / f"epoch={num_epochs - 1}.ckpt")
|
||||
path_last = str(tmpdir / "last.ckpt")
|
||||
assert path_last == model_checkpoint.last_model_path
|
||||
assert os.path.isfile(path_last_epoch)
|
||||
|
||||
ckpt_last_epoch = torch.load(path_last_epoch)
|
||||
ckpt_last = torch.load(path_last)
|
||||
|
@ -791,3 +811,25 @@ def test_configure_model_checkpoint(tmpdir):
|
|||
|
||||
with pytest.raises(MisconfigurationException, match="checkpoint_callback=False but found ModelCheckpoint"):
|
||||
Trainer(checkpoint_callback=False, callbacks=[callback1], **kwargs)
|
||||
|
||||
|
||||
def test_val_check_interval_checkpoint_files(tmpdir):
|
||||
""" Test correct checkpoint naming when validating/checkpointing multiple times per epoch. """
|
||||
model = EvalModelTemplate()
|
||||
model_checkpoint = ModelCheckpoint(
|
||||
dirpath=tmpdir,
|
||||
save_top_k=-1,
|
||||
monitor="val_acc",
|
||||
mode="max",
|
||||
verbose=True
|
||||
)
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
val_check_interval=0.2,
|
||||
max_epochs=1,
|
||||
limit_train_batches=10,
|
||||
callbacks=[model_checkpoint]
|
||||
)
|
||||
trainer.fit(model)
|
||||
files = sorted([p.name for p in Path(tmpdir).glob("*.ckpt")])
|
||||
assert files == [f"epoch=0-step={s}.ckpt" for s in [1, 3, 5, 7, 9]]
|
||||
|
|
|
@ -159,7 +159,7 @@ def test_comet_logger_dirs_creation(comet, comet_experiment, tmpdir, monkeypatch
|
|||
trainer.fit(model)
|
||||
|
||||
assert trainer.checkpoint_callback.dirpath == (tmpdir / 'test' / "1" / 'checkpoints')
|
||||
assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0.ckpt'}
|
||||
assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0-step=9.ckpt'}
|
||||
|
||||
|
||||
@patch('pytorch_lightning.loggers.comet.comet_ml')
|
||||
|
|
|
@ -115,7 +115,7 @@ def test_mlflow_log_dir(client, mlflow, tmpdir):
|
|||
)
|
||||
trainer.fit(model)
|
||||
assert trainer.checkpoint_callback.dirpath == (tmpdir / "exp-id" / "run-id" / 'checkpoints')
|
||||
assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0.ckpt'}
|
||||
assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0-step=0.ckpt'}
|
||||
|
||||
|
||||
def test_mlflow_logger_dirs_creation(tmpdir):
|
||||
|
@ -143,7 +143,7 @@ def test_mlflow_logger_dirs_creation(tmpdir):
|
|||
assert 'epoch' in os.listdir(tmpdir / exp_id / run_id / 'metrics')
|
||||
assert set(os.listdir(tmpdir / exp_id / run_id / 'params')) == model.hparams.keys()
|
||||
assert trainer.checkpoint_callback.dirpath == (tmpdir / exp_id / run_id / 'checkpoints')
|
||||
assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0.ckpt'}
|
||||
assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0-step=9.ckpt'}
|
||||
|
||||
|
||||
@mock.patch('pytorch_lightning.loggers.mlflow.mlflow')
|
||||
|
|
|
@ -116,7 +116,7 @@ def test_wandb_logger_dirs_creation(wandb, tmpdir):
|
|||
trainer.fit(model)
|
||||
|
||||
assert trainer.checkpoint_callback.dirpath == str(tmpdir / 'project' / version / 'checkpoints')
|
||||
assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0.ckpt'}
|
||||
assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0-step=9.ckpt'}
|
||||
|
||||
|
||||
def test_wandb_sanitize_callable_params(tmpdir):
|
||||
|
|
|
@ -430,7 +430,7 @@ def test_model_checkpoint_options(tmpdir, save_top_k, save_last, file_prefix, ex
|
|||
losses = [10, 9, 2.8, 5, 2.5]
|
||||
|
||||
checkpoint_callback = ModelCheckpoint(
|
||||
dirpath=tmpdir, monitor='checkpoint_on', save_top_k=save_top_k,
|
||||
dirpath=tmpdir, filename='{epoch}', monitor='checkpoint_on', save_top_k=save_top_k,
|
||||
save_last=save_last, prefix=file_prefix, verbose=1
|
||||
)
|
||||
checkpoint_callback.save_function = mock_save_function
|
||||
|
|
Loading…
Reference in New Issue