Add step index in checkpoint name (#3807)

* true final value of global step

* ch check

* tests

* save each validation interval

* wip

* add test

* add test

* wip

* fix tests, revert old edits, fix merge conflicts, update doctests

* test + bugfix

* sort files

* format test

* suggestion by ananth

* added changelog

* naming

* docs

* example

* suggestion

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

* fix test

* pep

* pep

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Jirka Borovec 2020-11-02 15:05:58 +01:00 committed by GitHub
parent f40d08679d
commit ef03c39ab7
8 changed files with 117 additions and 67 deletions


@@ -17,6 +17,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added multiclass AUROC metric ([#4236](https://github.com/PyTorchLightning/pytorch-lightning/pull/4236))
- Added global step indexing to the checkpoint name for a better sub-epoch checkpointing experience ([#3807](https://github.com/PyTorchLightning/pytorch-lightning/pull/3807))
### Changed
- W&B log in sync with Trainer step ([#4405](https://github.com/PyTorchLightning/pytorch-lightning/pull/4405))


@@ -101,7 +101,7 @@ class ModelCheckpoint(Callback):
... filename='{epoch}-{val_loss:.2f}-{other_metric:.2f}'
... )
By default, filename is ``None`` and will be set to ``'{epoch}'``.
By default, filename is ``None`` and will be set to ``'{epoch}-{step}'``.
Example::
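With this change, leaving ``filename`` unset embeds the global step alongside the epoch in every checkpoint name. A minimal sketch of the new default behaviour (the step value is illustrative)::

    from pytorch_lightning.callbacks import ModelCheckpoint

    # filename=None now falls back to the '{epoch}-{step}' template,
    # so checkpoints come out as e.g. 'epoch=1-step=19.ckpt'
    ckpt = ModelCheckpoint(dirpath='my/path/')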
@@ -222,16 +222,16 @@ class ModelCheckpoint(Callback):
monitor_candidates = self._monitor_candidates(trainer)
# ie: path/val_loss=0.5.ckpt
filepath = self._get_metric_interpolated_filepath_name(epoch, monitor_candidates)
filepath = self._get_metric_interpolated_filepath_name(monitor_candidates, epoch, global_step)
# callback supports multiple simultaneous modes
# here we call each mode sequentially
# Mode 1: save all checkpoints OR only the top k
if self.save_top_k:
self._save_top_k_checkpoints(monitor_candidates, trainer, pl_module, epoch, filepath)
self._save_top_k_checkpoints(monitor_candidates, trainer, pl_module, filepath)
# Mode 2: save the last checkpoint
self._save_last_checkpoint(trainer, pl_module, epoch, monitor_candidates, filepath)
self._save_last_checkpoint(trainer, pl_module, monitor_candidates, filepath)
def __validate_init_configuration(self):
if self.save_top_k is not None and self.save_top_k < -1:
@@ -360,16 +360,17 @@ class ModelCheckpoint(Callback):
cls,
filename: Optional[str],
epoch: int,
step: int,
metrics: Dict[str, Any],
prefix: str = "",
) -> str:
if not filename:
# filename is not set, use default name
filename = "{epoch}"
filename = "{epoch}-{step}"
# check and parse user passed keys in the string
groups = re.findall(r"(\{.*?)[:\}]", filename)
if len(groups) >= 0:
metrics["epoch"] = epoch
metrics.update({"epoch": epoch, 'step': step})
for group in groups:
name = group[1:]
filename = filename.replace(group, name + "={" + name)
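For readers skimming the hunk, this is a self-contained sketch of the interpolation logic as it stands after the change (simplified: prefix handling and ``CHECKPOINT_JOIN_CHAR`` are left out)::

    import re

    def format_name(filename, epoch, step, metrics):
        filename = filename or "{epoch}-{step}"  # new default template
        # capture '{name' up to ':' or '}', e.g. '{epoch' in '{epoch:03d}'
        groups = re.findall(r"(\{.*?)[:\}]", filename)
        metrics.update({"epoch": epoch, "step": step})
        for group in groups:
            name = group[1:]
            # rewrite '{epoch}' as 'epoch={epoch}' so .format() fills in the value
            filename = filename.replace(group, name + "={" + name)
            if name not in metrics:
                metrics[name] = 0  # unknown keys render as 0, per the doctests below
        return filename.format(**metrics)

    format_name(None, 3, 9, {})           # 'epoch=3-step=9'
    format_name('{epoch:03d}', 5, 2, {})  # 'epoch=005'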
@@ -379,7 +380,7 @@ class ModelCheckpoint(Callback):
return cls.CHECKPOINT_JOIN_CHAR.join([txt for txt in (prefix, filename) if txt])
def format_checkpoint_name(
self, epoch: int, metrics: Dict[str, Any], ver: Optional[int] = None
self, epoch: int, step: int, metrics: Dict[str, Any], ver: Optional[int] = None
) -> str:
"""Generate a filename according to the defined template.
@@ -387,24 +388,24 @@ class ModelCheckpoint(Callback):
>>> tmpdir = os.path.dirname(__file__)
>>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}')
>>> os.path.basename(ckpt.format_checkpoint_name(0, {}))
>>> os.path.basename(ckpt.format_checkpoint_name(0, 1, metrics={}))
'epoch=0.ckpt'
>>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch:03d}')
>>> os.path.basename(ckpt.format_checkpoint_name(5, {}))
>>> os.path.basename(ckpt.format_checkpoint_name(5, 2, metrics={}))
'epoch=005.ckpt'
>>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}-{val_loss:.2f}')
>>> os.path.basename(ckpt.format_checkpoint_name(2, dict(val_loss=0.123456)))
>>> os.path.basename(ckpt.format_checkpoint_name(2, 3, metrics=dict(val_loss=0.123456)))
'epoch=2-val_loss=0.12.ckpt'
>>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{missing:d}')
>>> os.path.basename(ckpt.format_checkpoint_name(0, {}))
>>> os.path.basename(ckpt.format_checkpoint_name(0, 4, metrics={}))
'missing=0.ckpt'
>>> ckpt = ModelCheckpoint(filename='{epoch}')
>>> os.path.basename(ckpt.format_checkpoint_name(0, {}))
'epoch=0.ckpt'
>>> ckpt = ModelCheckpoint(filename='{step}')
>>> os.path.basename(ckpt.format_checkpoint_name(0, 0, {}))
'step=0.ckpt'
"""
filename = self._format_checkpoint_name(
self.filename, epoch, metrics, prefix=self.prefix
self.filename, epoch, step, metrics, prefix=self.prefix
)
if ver is not None:
filename = self.CHECKPOINT_JOIN_CHAR.join((filename, f"v{ver}"))
@@ -479,13 +480,11 @@ class ModelCheckpoint(Callback):
)
raise MisconfigurationException(m)
def _get_metric_interpolated_filepath_name(self, epoch, ckpt_name_metrics):
filepath = self.format_checkpoint_name(epoch, ckpt_name_metrics)
def _get_metric_interpolated_filepath_name(self, ckpt_name_metrics: Dict[str, Any], epoch: int, step: int):
filepath = self.format_checkpoint_name(epoch, step, ckpt_name_metrics)
version_cnt = 0
while self._fs.exists(filepath):
filepath = self.format_checkpoint_name(
epoch, ckpt_name_metrics, ver=version_cnt
)
filepath = self.format_checkpoint_name(epoch, step, ckpt_name_metrics, ver=version_cnt)
# this epoch called before
version_cnt += 1
return filepath
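Note the collision handling above: when the interpolated name already exists on disk, a ``v{n}`` suffix is joined on. In the same doctest style as the examples in ``format_checkpoint_name``::

    >>> tmpdir = os.path.dirname(__file__)
    >>> ckpt = ModelCheckpoint(dirpath=tmpdir)
    >>> os.path.basename(ckpt.format_checkpoint_name(0, 9, {}, ver=0))
    'epoch=0-step=9-v0.ckpt'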
@@ -494,9 +493,10 @@ class ModelCheckpoint(Callback):
ckpt_name_metrics = deepcopy(trainer.logger_connector.logged_metrics)
ckpt_name_metrics.update(trainer.logger_connector.callback_metrics)
ckpt_name_metrics.update(trainer.logger_connector.progress_bar_metrics)
ckpt_name_metrics.update({"step": trainer.global_step, "epoch": trainer.current_epoch})
return ckpt_name_metrics
def _save_last_checkpoint(self, trainer, pl_module, epoch, ckpt_name_metrics, filepath):
def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics, filepath):
should_save_last = self.monitor is None or self.save_last
if not should_save_last:
return
@@ -506,7 +506,11 @@ class ModelCheckpoint(Callback):
# when user ALSO asked for the 'last.ckpt' change the name
if self.save_last:
last_filepath = self._format_checkpoint_name(
self.CHECKPOINT_NAME_LAST, epoch, ckpt_name_metrics, prefix=self.prefix
self.CHECKPOINT_NAME_LAST,
trainer.current_epoch,
trainer.global_step,
ckpt_name_metrics,
prefix=self.prefix
)
last_filepath = os.path.join(self.dirpath, f"{last_filepath}.ckpt")
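Since the last-checkpoint name is now formatted with ``trainer.current_epoch`` and ``trainer.global_step``, a custom ``CHECKPOINT_NAME_LAST`` template can reference the step as well (the stock template is plain ``'last'``); an illustrative call::

    >>> ModelCheckpoint._format_checkpoint_name('last-{epoch}-{step}', 2, 29, {})
    'last-epoch=2-step=29'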
@@ -523,17 +527,19 @@ class ModelCheckpoint(Callback):
if self.monitor is None:
self.best_model_path = self.last_model_path
def _save_top_k_checkpoints(self, metrics, trainer, pl_module, epoch, filepath):
def _save_top_k_checkpoints(self, metrics, trainer, pl_module, filepath):
current = metrics.get(self.monitor)
epoch = metrics.get("epoch")
step = metrics.get("step")
if not isinstance(current, torch.Tensor) and current is not None:
current = torch.tensor(current, device=pl_module.device)
if self.check_monitor_top_k(current):
self._update_best_and_save(filepath, current, epoch, trainer, pl_module)
self._update_best_and_save(filepath, current, epoch, step, trainer, pl_module)
elif self.verbose:
rank_zero_info(
f"Epoch {epoch:d}: {self.monitor} was not in top {self.save_top_k}"
f"Epoch {epoch:d}, step {step:d}: {self.monitor} was not in top {self.save_top_k}"
)
def _is_valid_monitor_key(self, metrics):
@@ -544,11 +550,11 @@ class ModelCheckpoint(Callback):
filepath: str,
current: torch.Tensor,
epoch: int,
step: int,
trainer,
pl_module,
):
k = epoch + 1 if self.save_top_k == -1 else self.save_top_k
k = len(self.best_k_models) + 1 if self.save_top_k == -1 else self.save_top_k
del_list = []
if len(self.best_k_models) == k and k > 0:
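The switch away from ``epoch + 1`` is what makes ``save_top_k=-1`` safe with sub-epoch checkpointing; roughly::

    # with save_top_k == -1 and several saves inside epoch 0:
    #   old: k = epoch + 1 = 1           -> len(best_k_models) reaches k after one
    #        save, so each later same-epoch checkpoint evicts the previous one
    #   new: k = len(best_k_models) + 1  -> the cap stays one ahead of the count,
    #        so nothing is ever evicted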
@@ -575,9 +581,8 @@ class ModelCheckpoint(Callback):
if self.verbose:
rank_zero_info(
f"Epoch {epoch:d}: {self.monitor} reached"
f" {current:0.5f} (best {self.best_model_score:0.5f}),"
f" saving model to {filepath} as top {k}"
f"Epoch {epoch:d}, global step {step:d}: {self.monitor} reached {current:0.5f}"
f' (best {self.best_model_score:0.5f}), saving model to "{filepath}" as top {k}'
)
self._save_model(filepath, trainer, pl_module)


@@ -250,9 +250,10 @@ class EvaluationLoop(object):
# depre warning
if eval_results is not None and user_reduced:
step = 'testing_epoch_end' if self.testing else 'validation_epoch_end'
m = f'The {step} should not return anything as of 9.1.' \
f'to log, use self.log(...) or self.write(...) directly in the LightningModule'
self.warning_cache.warn(m)
self.warning_cache.warn(
f'The {step} should not return anything as of 9.1.'
' To log, use self.log(...) or self.write(...) directly in the LightningModule'
)
if using_eval_result and not user_reduced:
eval_results = self.__auto_reduce_result_objs(outputs)


@@ -100,7 +100,7 @@ def test_model_checkpoint_to_yaml(tmpdir, save_top_k):
path_yaml = os.path.join(tmpdir, 'best_k_models.yaml')
checkpoint.to_yaml(path_yaml)
d = yaml.full_load(open(path_yaml, 'r'))
best_k = {k: v.item() for k, v in checkpoint.best_k_models.items()}
best_k = {k: v for k, v in checkpoint.best_k_models.items()}
assert d == best_k
@@ -185,67 +185,72 @@ def test_model_checkpoint_no_extraneous_invocations(tmpdir):
def test_model_checkpoint_format_checkpoint_name(tmpdir):
# empty filename:
ckpt_name = ModelCheckpoint._format_checkpoint_name('', 3, {})
assert ckpt_name == 'epoch=3'
ckpt_name = ModelCheckpoint._format_checkpoint_name('', 3, 2, {})
assert ckpt_name == 'epoch=3-step=2'
ckpt_name = ModelCheckpoint._format_checkpoint_name(None, 3, {}, prefix='test')
assert ckpt_name == 'test-epoch=3'
ckpt_name = ModelCheckpoint._format_checkpoint_name(None, 3, 2, {}, prefix='test')
assert ckpt_name == 'test-epoch=3-step=2'
# no groups case:
ckpt_name = ModelCheckpoint._format_checkpoint_name('ckpt', 3, {}, prefix='test')
ckpt_name = ModelCheckpoint._format_checkpoint_name('ckpt', 3, 2, {}, prefix='test')
assert ckpt_name == 'test-ckpt'
# no prefix
ckpt_name = ModelCheckpoint._format_checkpoint_name('{epoch:03d}-{acc}', 3, {'acc': 0.03})
ckpt_name = ModelCheckpoint._format_checkpoint_name('{epoch:03d}-{acc}', 3, 2, {'acc': 0.03})
assert ckpt_name == 'epoch=003-acc=0.03'
# prefix
char_org = ModelCheckpoint.CHECKPOINT_JOIN_CHAR
ModelCheckpoint.CHECKPOINT_JOIN_CHAR = '@'
ckpt_name = ModelCheckpoint._format_checkpoint_name('{epoch},{acc:.5f}', 3, {'acc': 0.03}, prefix='test')
ckpt_name = ModelCheckpoint._format_checkpoint_name('{epoch},{acc:.5f}', 3, 2, {'acc': 0.03}, prefix='test')
assert ckpt_name == 'test@epoch=3,acc=0.03000'
ModelCheckpoint.CHECKPOINT_JOIN_CHAR = char_org
# no dirpath set
ckpt_name = ModelCheckpoint(monitor='early_stop_on', dirpath=None).format_checkpoint_name(3, {})
assert ckpt_name == 'epoch=3.ckpt'
ckpt_name = ModelCheckpoint(monitor='early_stop_on', dirpath='').format_checkpoint_name(5, {})
assert ckpt_name == 'epoch=5.ckpt'
ckpt_name = ModelCheckpoint(monitor='early_stop_on', dirpath=None).format_checkpoint_name(3, 2, {})
assert ckpt_name == 'epoch=3-step=2.ckpt'
ckpt_name = ModelCheckpoint(monitor='early_stop_on', dirpath='').format_checkpoint_name(5, 4, {})
assert ckpt_name == 'epoch=5-step=4.ckpt'
# CWD
ckpt_name = ModelCheckpoint(monitor='early_stop_on', dirpath='.').format_checkpoint_name(3, {})
assert ckpt_name == str(Path('.').resolve() / 'epoch=3.ckpt')
ckpt_name = ModelCheckpoint(monitor='early_stop_on', dirpath='.').format_checkpoint_name(3, 4, {})
assert ckpt_name == str(Path('.').resolve() / 'epoch=3-step=4.ckpt')
# with ver
ckpt_name = ModelCheckpoint(
monitor='early_stop_on', dirpath=tmpdir, filename='name', prefix='test'
).format_checkpoint_name(3, {}, ver=3)
).format_checkpoint_name(3, 2, {}, ver=3)
assert ckpt_name == tmpdir / 'test-name-v3.ckpt'
# using slashes
ckpt_name = ModelCheckpoint(
monitor='early_stop_on', dirpath=None, filename='{epoch}_{val/loss:.5f}'
).format_checkpoint_name(4, {'val/loss': 0.03})
).format_checkpoint_name(4, 3, {'val/loss': 0.03})
assert ckpt_name == 'epoch=4_val/loss=0.03000.ckpt'
# TODO: Checks with filepath. To be removed in v1.2
# CWD
ckpt_name = ModelCheckpoint(monitor='early_stop_on', filepath='.').format_checkpoint_name(3, {})
assert ckpt_name == str(Path('.').resolve() / 'epoch=3.ckpt')
ckpt_name = ModelCheckpoint(monitor='early_stop_on', filepath='.').format_checkpoint_name(3, 2, {})
assert ckpt_name == str(Path('.').resolve() / 'epoch=3-step=2.ckpt')
# dir does not exist so it is used as filename
filepath = tmpdir / 'dir'
ckpt_name = ModelCheckpoint(monitor='early_stop_on', filepath=filepath, prefix='test').format_checkpoint_name(3, {})
ckpt_name = ModelCheckpoint(
monitor='early_stop_on', filepath=filepath, prefix='test'
).format_checkpoint_name(3, 2, {})
assert ckpt_name == tmpdir / 'test-dir.ckpt'
# now, dir exists
os.mkdir(filepath)
ckpt_name = ModelCheckpoint(monitor='early_stop_on', filepath=filepath, prefix='test').format_checkpoint_name(3, {})
assert ckpt_name == filepath / 'test-epoch=3.ckpt'
ckpt_name = ModelCheckpoint(
monitor='early_stop_on', filepath=filepath, prefix='test'
).format_checkpoint_name(3, 2, {})
assert ckpt_name == filepath / 'test-epoch=3-step=2.ckpt'
def test_model_checkpoint_save_last(tmpdir):
"""Tests that save_last produces only one last checkpoint."""
seed_everything()
model = EvalModelTemplate()
epochs = 3
ModelCheckpoint.CHECKPOINT_NAME_LAST = 'last-{epoch}'
@@ -257,10 +262,15 @@ def test_model_checkpoint_save_last(tmpdir):
logger=False,
)
trainer.fit(model)
last_filename = model_checkpoint._format_checkpoint_name(ModelCheckpoint.CHECKPOINT_NAME_LAST, epochs - 1, {})
last_filename = model_checkpoint._format_checkpoint_name(
ModelCheckpoint.CHECKPOINT_NAME_LAST, trainer.current_epoch, trainer.global_step, {}
)
last_filename = last_filename + '.ckpt'
assert str(tmpdir / last_filename) == model_checkpoint.last_model_path
assert set(os.listdir(tmpdir)) == set([f'epoch={i}.ckpt' for i in range(epochs)] + [last_filename])
assert set(os.listdir(tmpdir)) == set(
[f"epoch={i}-step={j}.ckpt" for i, j in zip(range(epochs), [9, 19, 29])] + [last_filename]
)
ModelCheckpoint.CHECKPOINT_NAME_LAST = 'last'
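The expected step values follow from the test setup, assuming 10 training batches per epoch (configured earlier in the test, outside this hunk): ``global_step`` is 0-indexed, so each epoch ends at ``10 * (epoch + 1) - 1``::

    >>> [10 * (e + 1) - 1 for e in range(3)]
    [9, 19, 29]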
@@ -295,6 +305,7 @@ def test_none_monitor_save_last(tmpdir):
def test_model_checkpoint_none_monitor(tmpdir):
""" Test that it is possible to save all checkpoints when monitor=None. """
seed_everything()
model = EvalModelTemplate()
model.validation_step = model.validation_step_no_monitor
model.validation_epoch_end = model.validation_epoch_end_no_monitor
@@ -311,13 +322,13 @@ def test_model_checkpoint_none_monitor(tmpdir):
# these should not be set if monitor is None
assert checkpoint_callback.monitor is None
assert checkpoint_callback.best_model_path == checkpoint_callback.last_model_path == tmpdir / 'epoch=1.ckpt'
assert checkpoint_callback.best_model_path == checkpoint_callback.last_model_path == tmpdir / 'epoch=1-step=19.ckpt'
assert checkpoint_callback.best_model_score == 0
assert checkpoint_callback.best_k_models == {}
assert checkpoint_callback.kth_best_model_path == ''
# check that the correct ckpts were created
expected = [f'epoch={e}.ckpt' for e in range(epochs)]
expected = [f'epoch={i}-step={j}.ckpt' for i, j in zip(range(epochs), [9, 19])]
assert set(os.listdir(tmpdir)) == set(expected)
@@ -325,13 +336,14 @@ def test_model_checkpoint_period(tmpdir, period):
def test_model_checkpoint_period(tmpdir, period):
model = EvalModelTemplate()
epochs = 5
checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, save_top_k=-1, period=period)
checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}', save_top_k=-1, period=period)
trainer = Trainer(
default_root_dir=tmpdir,
checkpoint_callback=checkpoint_callback,
max_epochs=epochs,
limit_train_batches=0.1,
limit_val_batches=0.1,
val_check_interval=1.0,
logger=False,
)
trainer.fit(model)
@@ -372,12 +384,19 @@ def test_model_checkpoint_topk_all(tmpdir):
return {'epoch': self.current_epoch}
model = CustomModel()
checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, monitor="epoch", mode='max', save_top_k=-1)
checkpoint_callback = ModelCheckpoint(
dirpath=tmpdir,
filename="{epoch}",
monitor="epoch",
mode='max',
save_top_k=-1,
)
trainer = Trainer(
default_root_dir=tmpdir,
checkpoint_callback=checkpoint_callback,
max_epochs=epochs,
logger=False,
val_check_interval=1.0,
)
trainer.fit(model)
@@ -439,7 +458,7 @@ def test_default_checkpoint_behavior(tmpdir):
# make sure the checkpoint we saved has the metric in the name
ckpts = os.listdir(os.path.join(tmpdir, 'lightning_logs', 'version_0', 'checkpoints'))
assert len(ckpts) == 1
assert ckpts[0] == 'epoch=2.ckpt'
assert ckpts[0] == 'epoch=2-step=14.ckpt'
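Assuming this test runs 5 training batches per epoch (inferred from the asserted name, not shown in the hunk), three epochs give 15 optimizer steps and the final, 0-indexed global step is::

    >>> 3 * 5 - 1
    14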
def test_ckpt_metric_names_results(tmpdir):
@@ -497,7 +516,7 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
model = EvalModelTemplate()
num_epochs = 3
model_checkpoint = ModelCheckpoint(
monitor='early_stop_on', dirpath=tmpdir, save_top_k=num_epochs, save_last=True
monitor='early_stop_on', dirpath=tmpdir, filename="{epoch}", save_top_k=num_epochs, save_last=True
)
trainer = Trainer(
default_root_dir=tmpdir,
@@ -509,6 +528,7 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
path_last_epoch = str(tmpdir / f"epoch={num_epochs - 1}.ckpt")
path_last = str(tmpdir / "last.ckpt")
assert path_last == model_checkpoint.last_model_path
assert os.path.isfile(path_last_epoch)
ckpt_last_epoch = torch.load(path_last_epoch)
ckpt_last = torch.load(path_last)
@@ -791,3 +811,25 @@ def test_configure_model_checkpoint(tmpdir):
with pytest.raises(MisconfigurationException, match="checkpoint_callback=False but found ModelCheckpoint"):
Trainer(checkpoint_callback=False, callbacks=[callback1], **kwargs)
def test_val_check_interval_checkpoint_files(tmpdir):
""" Test correct checkpoint naming when validating/checkpointing multiple times per epoch. """
model = EvalModelTemplate()
model_checkpoint = ModelCheckpoint(
dirpath=tmpdir,
save_top_k=-1,
monitor="val_acc",
mode="max",
verbose=True
)
trainer = Trainer(
default_root_dir=tmpdir,
val_check_interval=0.2,
max_epochs=1,
limit_train_batches=10,
callbacks=[model_checkpoint]
)
trainer.fit(model)
files = sorted([p.name for p in Path(tmpdir).glob("*.ckpt")])
assert files == [f"epoch=0-step={s}.ckpt" for s in [1, 3, 5, 7, 9]]
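The expected steps are a direct consequence of the configuration above: ``limit_train_batches=10`` with ``val_check_interval=0.2`` triggers validation, and hence checkpointing, after every 2nd batch, and ``global_step`` is 0-indexed::

    >>> [2 * (i + 1) - 1 for i in range(5)]
    [1, 3, 5, 7, 9]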


@@ -159,7 +159,7 @@ def test_comet_logger_dirs_creation(comet, comet_experiment, tmpdir, monkeypatch
trainer.fit(model)
assert trainer.checkpoint_callback.dirpath == (tmpdir / 'test' / "1" / 'checkpoints')
assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0.ckpt'}
assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0-step=9.ckpt'}
@patch('pytorch_lightning.loggers.comet.comet_ml')


@@ -115,7 +115,7 @@ def test_mlflow_log_dir(client, mlflow, tmpdir):
)
trainer.fit(model)
assert trainer.checkpoint_callback.dirpath == (tmpdir / "exp-id" / "run-id" / 'checkpoints')
assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0.ckpt'}
assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0-step=0.ckpt'}
def test_mlflow_logger_dirs_creation(tmpdir):
@@ -143,7 +143,7 @@ def test_mlflow_logger_dirs_creation(tmpdir):
assert 'epoch' in os.listdir(tmpdir / exp_id / run_id / 'metrics')
assert set(os.listdir(tmpdir / exp_id / run_id / 'params')) == model.hparams.keys()
assert trainer.checkpoint_callback.dirpath == (tmpdir / exp_id / run_id / 'checkpoints')
assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0.ckpt'}
assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0-step=9.ckpt'}
@mock.patch('pytorch_lightning.loggers.mlflow.mlflow')


@@ -116,7 +116,7 @@ def test_wandb_logger_dirs_creation(wandb, tmpdir):
trainer.fit(model)
assert trainer.checkpoint_callback.dirpath == str(tmpdir / 'project' / version / 'checkpoints')
assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0.ckpt'}
assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0-step=9.ckpt'}
def test_wandb_sanitize_callable_params(tmpdir):


@@ -430,7 +430,7 @@ def test_model_checkpoint_options(tmpdir, save_top_k, save_last, file_prefix, ex
losses = [10, 9, 2.8, 5, 2.5]
checkpoint_callback = ModelCheckpoint(
dirpath=tmpdir, monitor='checkpoint_on', save_top_k=save_top_k,
dirpath=tmpdir, filename='{epoch}', monitor='checkpoint_on', save_top_k=save_top_k,
save_last=save_last, prefix=file_prefix, verbose=1
)
checkpoint_callback.save_function = mock_save_function