diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index 32e643c11d..5890aa5e84 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -339,10 +339,11 @@ class ModelCheckpoint(Callback):
 
         self.epoch_last_check = epoch
 
-        filepath = self.format_checkpoint_name(epoch, metrics)
+        ckpt_name_metrics = trainer.logged_metrics
+        filepath = self.format_checkpoint_name(epoch, ckpt_name_metrics)
         version_cnt = 0
         while gfile.exists(filepath):
-            filepath = self.format_checkpoint_name(epoch, metrics, ver=version_cnt)
+            filepath = self.format_checkpoint_name(epoch, ckpt_name_metrics, ver=version_cnt)
             # this epoch called before
             version_cnt += 1
 
diff --git a/pytorch_lightning/trainer/logging.py b/pytorch_lightning/trainer/logging.py
index f84b071114..47e5dbb4de 100644
--- a/pytorch_lightning/trainer/logging.py
+++ b/pytorch_lightning/trainer/logging.py
@@ -24,6 +24,7 @@ class TrainerLoggingMixin(ABC):
     default_root_dir: str
     slurm_job_id: int
     num_gpus: int
+    logged_metrics: ...
 
     def configure_logger(self, logger):
         if logger is True:
@@ -75,6 +76,8 @@
             self.logger.agg_and_log_metrics(scalar_metrics, step=step)
             self.logger.save()
 
+        # track the logged metrics
+        self.logged_metrics = scalar_metrics
         self.dev_debugger.track_logged_metrics_history(scalar_metrics)
 
     def add_progress_bar_metrics(self, metrics):
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index aa9b670f56..eb2e35a83f 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -374,6 +374,7 @@ class Trainer(
         self.batch_idx = 0
         self.progress_bar_metrics = {}
         self.callback_metrics = {}
+        self.logged_metrics = {}
         self.num_training_batches = 0
         self.num_val_batches = []
         self.num_test_batches = []
diff --git a/tests/callbacks/test_model_checkpoint.py b/tests/callbacks/test_model_checkpoint.py
index dfcc1e0368..976fc887d3 100644
--- a/tests/callbacks/test_model_checkpoint.py
+++ b/tests/callbacks/test_model_checkpoint.py
@@ -1,4 +1,5 @@
 import os
+import re
 import pickle
 import platform
 from pathlib import Path
@@ -128,3 +129,58 @@
     model_last = EvalModelTemplate.load_from_checkpoint(path_last)
     for w0, w1 in zip(model_last_epoch.parameters(), model_last.parameters()):
         assert w0.eq(w1).all()
+
+
+def test_ckpt_metric_names(tmpdir):
+    model = EvalModelTemplate()
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        gradient_clip_val=1.0,
+        overfit_batches=0.20,
+        progress_bar_refresh_rate=0,
+        limit_train_batches=0.01,
+        limit_val_batches=0.01,
+        checkpoint_callback=ModelCheckpoint(filepath=tmpdir + '/{val_loss:.2f}')
+    )
+
+    trainer.fit(model)
+
+    # make sure the checkpoint we saved has the metric in the name
+    ckpts = os.listdir(tmpdir)
+    ckpts = [x for x in ckpts if 'val_loss' in x]
+    assert len(ckpts) == 1
+    val = re.sub('[^0-9.]', '', ckpts[0])
+    assert len(val) > 3
+
+
+def test_ckpt_metric_names_results(tmpdir):
+    model = EvalModelTemplate()
+    model.training_step = model.training_step_result_obj
+    model.training_step_end = None
+    model.training_epoch_end = None
+
+    model.validation_step = model.validation_step_result_obj
+    model.validation_step_end = None
+    model.validation_epoch_end = None
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        gradient_clip_val=1.0,
+        overfit_batches=0.20,
+        progress_bar_refresh_rate=0,
+        limit_train_batches=0.01,
+        limit_val_batches=0.01,
+        checkpoint_callback=ModelCheckpoint(filepath=tmpdir + '/{val_loss:.2f}')
+    )
+
+    trainer.fit(model)
+
+    # make sure the checkpoint we saved has the metric in the name
+    ckpts = os.listdir(tmpdir)
+    ckpts = [x for x in ckpts if 'val_loss' in x]
+    assert len(ckpts) == 1
+    val = re.sub('[^0-9.]', '', ckpts[0])
+    assert len(val) > 3
diff --git a/tests/trainer/test_eval_loop_dict_return.py b/tests/trainer/test_eval_loop_dict_return.py
index d4e845bade..d726f01c1c 100644
--- a/tests/trainer/test_eval_loop_dict_return.py
+++ b/tests/trainer/test_eval_loop_dict_return.py
@@ -136,7 +136,7 @@ def test_validation_step_dict_return(tmpdir):
         assert k in eval_results[1]
 
     # ensure all the keys ended up as candidates for callbacks
-    assert len(trainer.callback_metrics) == 8
+    assert len(trainer.callback_metrics) == 7
 
     # make sure correct steps were called
     assert model.validation_step_called
@@ -211,7 +211,7 @@ def test_val_step_step_end(tmpdir):
         assert k in eval_results[1]
 
     # ensure all the keys ended up as candidates for callbacks
-    assert len(trainer.callback_metrics) == 9
+    assert len(trainer.callback_metrics) == 8
 
     # make sure correct steps were called
     assert model.validation_step_called
@@ -254,7 +254,7 @@ def test_no_val_step_end(tmpdir):
         assert k in eval_results
 
     # ensure all the keys ended up as candidates for callbacks
-    assert len(trainer.callback_metrics) == 9
+    assert len(trainer.callback_metrics) == 8
 
     # make sure correct steps were called
     assert model.validation_step_called
@@ -297,7 +297,7 @@ def test_full_val_loop(tmpdir):
         assert k in eval_results
 
     # ensure all the keys ended up as candidates for callbacks
-    assert len(trainer.callback_metrics) == 10
+    assert len(trainer.callback_metrics) == 9
 
     # make sure correct steps were called
     assert model.validation_step_called
diff --git a/tests/trainer/test_trainer_steps_scalar_return.py b/tests/trainer/test_trainer_steps_scalar_return.py
index 65a92a49de..40c716ac47 100644
--- a/tests/trainer/test_trainer_steps_scalar_return.py
+++ b/tests/trainer/test_trainer_steps_scalar_return.py
@@ -108,7 +108,7 @@ def test_full_training_loop_scalar(tmpdir):
     assert model.training_epoch_end_called
 
     # assert epoch end metrics were added
-    assert 'epoch' in trainer.callback_metrics and len(trainer.callback_metrics) == 1
+    assert len(trainer.callback_metrics) == 0
     assert len(trainer.progress_bar_metrics) == 0
 
     # make sure training outputs what is expected
@@ -151,7 +151,7 @@ def test_train_step_epoch_end_scalar(tmpdir):
     assert model.training_epoch_end_called
 
     # assert epoch end metrics were added
-    assert 'epoch' in trainer.callback_metrics and len(trainer.callback_metrics) == 1
+    assert len(trainer.callback_metrics) == 0
     assert len(trainer.progress_bar_metrics) == 0
 
     # make sure training outputs what is expected
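
How the change works: the checkpoint file name is produced by ModelCheckpoint.format_checkpoint_name, which fills each '{name}' placeholder in the filepath template from the metrics dict it is handed. This patch hands it trainer.logged_metrics (the metrics most recently sent to the logger) instead of the callback metrics, so a user-logged metric such as val_loss becomes available to the template. A minimal sketch of the formatting step, not part of the patch (the '/tmp/ckpts' path is illustrative):

    from pytorch_lightning.callbacks import ModelCheckpoint

    # '{epoch}' and '{val_loss:.2f}' are filled from the dict passed to
    # format_checkpoint_name -- after this patch, trainer.logged_metrics.
    ckpt = ModelCheckpoint(filepath='/tmp/ckpts/{epoch}-{val_loss:.2f}')
    print(ckpt.format_checkpoint_name(2, dict(val_loss=0.123456)))
    # -> '/tmp/ckpts/epoch=2-val_loss=0.12.ckpt'

This is also why the tests above only need to scan tmpdir for a file whose name contains 'val_loss' followed by a formatted number.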