From 7c61fc7c27ef81354af399c04e939e57c65ce046 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Thu, 1 Oct 2020 02:31:11 -0400
Subject: [PATCH] ref: fixes logging for eval steps (#3763)

* fixes logging for eval steps
---
 pl_examples/basic_examples/image_classifier.py  |  3 ++-
 pytorch_lightning/callbacks/model_checkpoint.py | 15 +++++----------
 .../trainer/connectors/logger_connector.py      |  3 +++
 3 files changed, 10 insertions(+), 11 deletions(-)

diff --git a/pl_examples/basic_examples/image_classifier.py b/pl_examples/basic_examples/image_classifier.py
index c453822d02..04e965c049 100644
--- a/pl_examples/basic_examples/image_classifier.py
+++ b/pl_examples/basic_examples/image_classifier.py
@@ -54,13 +54,14 @@ class LitClassifier(pl.LightningModule):
         x, y = batch
         y_hat = self.backbone(x)
         loss = F.cross_entropy(y_hat, y)
+        self.log('train_loss', loss, on_epoch=True)
         return loss
 
     def validation_step(self, batch, batch_idx):
         x, y = batch
         y_hat = self.backbone(x)
         loss = F.cross_entropy(y_hat, y)
-        self.log('valid_loss', loss)
+        self.log('valid_loss', loss, on_step=True)
 
     def test_step(self, batch, batch_idx):
         x, y = batch
diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index 4357163f46..f0a8e159a2 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -296,6 +296,9 @@ class ModelCheckpoint(Callback):
             raise ValueError(".save_function() not set")
 
     def check_monitor_top_k(self, current) -> bool:
+        if current is None:
+            return False
+
         if self.save_top_k == -1:
             return True
 
@@ -421,7 +424,7 @@ class ModelCheckpoint(Callback):
         if self.monitor is None and 'checkpoint_on' in metrics:
             self.monitor = 'checkpoint_on'
 
-        if self.save_top_k is None:
+        if self.save_top_k is None and self.monitor is not None:
             self.save_top_k = 1
 
     def _validate_monitor_key(self, trainer):
@@ -486,15 +489,7 @@ class ModelCheckpoint(Callback):
             if not isinstance(current, torch.Tensor) and current is not None:
                 current = torch.tensor(current, device=pl_module.device)
 
-            if current is None:
-                m = f"Can save best model only with {self.monitor} available, skipping."
-                if self.monitor == 'checkpoint_on':
-                    m = (
-                        'No checkpoint_on found. HINT: Did you set it in '
-                        'EvalResult(checkpoint_on=tensor) or TrainResult(checkpoint_on=tensor)?'
-                    )
-                rank_zero_warn(m, RuntimeWarning)
-            elif self.check_monitor_top_k(current):
+            if self.check_monitor_top_k(current):
                 self._update_best_and_save(filepath, current, epoch, trainer, pl_module)
             elif self.verbose:
                 rank_zero_info(
diff --git a/pytorch_lightning/trainer/connectors/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector.py
index 8030d57fe7..7a56639c63 100644
--- a/pytorch_lightning/trainer/connectors/logger_connector.py
+++ b/pytorch_lightning/trainer/connectors/logger_connector.py
@@ -157,6 +157,9 @@ class LoggerConnector:
         # track the final results for the dataloader
         self.eval_loop_results.append(deepcopy(self.callback_metrics))
 
+        # actually log
+        self.log_metrics(logger_metrics, {}, step=self.trainer.global_step)
+
     def __rename_keys_by_dataloader_idx(self, metrics, dataloader_idx, num_loaders):
         if num_loaders == 1:
             return metrics
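
Note (not part of the patch): a minimal sketch of the logging pattern the
updated example relies on. The module below is illustrative (the linear
backbone, input shape, and optimizer are stand-ins); only the self.log
calls and their flags come from the diff above.

    import torch
    import torch.nn.functional as F
    import pytorch_lightning as pl

    class SketchClassifier(pl.LightningModule):
        """Illustrative module; only the self.log calls mirror the patch."""

        def __init__(self):
            super().__init__()
            # hypothetical stand-in for the example's backbone
            self.backbone = torch.nn.Linear(28 * 28, 10)

        def training_step(self, batch, batch_idx):
            x, y = batch
            loss = F.cross_entropy(self.backbone(x), y)
            # on_epoch=True: besides the default per-step training log,
            # also accumulate and log an epoch-level aggregate
            self.log('train_loss', loss, on_epoch=True)
            return loss

        def validation_step(self, batch, batch_idx):
            x, y = batch
            loss = F.cross_entropy(self.backbone(x), y)
            # on_step=True: log the metric for every validation batch; the
            # LoggerConnector change above forwards these eval metrics to
            # the logger at trainer.global_step
            self.log('valid_loss', loss, on_step=True)

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=1e-3)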
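
Note (not part of the patch): with check_monitor_top_k now returning False
when the monitored value is missing, a ModelCheckpoint pointed at an absent
key silently skips saving instead of emitting the old checkpoint_on
warning. A hedged usage sketch follows; the 'valid_loss' key is an
assumption tied to the example above, and how the callback is attached to
the Trainer depends on the Lightning version in use.

    from pytorch_lightning.callbacks import ModelCheckpoint

    # Monitors the 'valid_loss' logged in validation_step above. If that
    # key is never logged, check_monitor_top_k(None) returns False and no
    # best-model checkpoint is written.
    checkpoint_cb = ModelCheckpoint(monitor='valid_loss', save_top_k=1, mode='min')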