diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 7419e1746f..8a550e5df9 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -418,7 +418,8 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks): The outputs here are strictly for the progress bar. If you don't need to display anything, don't return anything. Any keys present in 'log', 'progress_bar' or the rest of the dictionary - are available for callbacks to access. + are available for callbacks to access. If you want to manually set current step, you can specify it with + 'step' key in the 'log' Dict. Example ------- @@ -468,7 +469,7 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks): # show val_loss and val_acc in progress bar but only log val_loss results = { 'progress_bar': tqdm_dict, - 'log': {'val_loss': val_loss_mean.item()} + 'log': {'val_loss': val_loss_mean.item(), 'step': self.current_epoch} } return results diff --git a/pytorch_lightning/trainer/logging.py b/pytorch_lightning/trainer/logging.py index 2025bd511e..34b1c114b3 100644 --- a/pytorch_lightning/trainer/logging.py +++ b/pytorch_lightning/trainer/logging.py @@ -39,13 +39,12 @@ class TrainerLoggingMixin(ABC): def log_metrics(self, metrics, grad_norm_dic, step=None): """Logs the metric dict passed in. - - :param metrics: - :param grad_norm_dic: + If `step` parameter is None and `step` key is presented is metrics, + uses metrics["step"] as a step + :param metrics (dict): Metric values + :param grad_norm_dic (dict): Gradient norms + :param step (int): Step for which metrics should be logged. Default value corresponds to `self.global_step` """ - # added metrics by Lightning for convenience - metrics['epoch'] = self.current_epoch - # add gpu memory if self.on_gpu and self.log_gpu_memory: mem_map = memory.get_memory_profile(self.log_gpu_memory) @@ -57,7 +56,12 @@ class TrainerLoggingMixin(ABC): # turn all tensors to scalars scalar_metrics = self.metrics_to_scalars(metrics) - step = step if step is not None else self.global_step + if "step" in scalar_metrics and step is None: + step = scalar_metrics.pop("step") + else: + # added metrics by Lightning for convenience + metrics['epoch'] = self.current_epoch + step = step if step is not None else self.global_step # log actual metrics if self.proc_rank == 0 and self.logger is not None: self.logger.log_metrics(scalar_metrics, step=step) diff --git a/tests/test_logging.py b/tests/test_logging.py index c0166796ca..0d4104ef7a 100644 --- a/tests/test_logging.py +++ b/tests/test_logging.py @@ -376,3 +376,33 @@ def test_custom_logger(tmpdir): assert logger.hparams_logged == hparams assert logger.metrics_logged != {} assert logger.finalized_status == "success" + + +def test_adding_step_key(tmpdir): + logged_step = 0 + + def _validation_end(outputs): + nonlocal logged_step + logged_step += 1 + return {"log": {"step": logged_step, "val_acc": logged_step / 10}} + + def _log_metrics_decorator(log_metrics_fn): + def decorated(metrics, step): + if "val_acc" in metrics: + assert step == logged_step + return log_metrics_fn(metrics, step) + + return decorated + + model, hparams = tutils.get_model() + model.validation_end = _validation_end + trainer_options = dict( + max_epochs=4, + default_save_path=tmpdir, + train_percent_check=0.001, + val_percent_check=0.01, + num_sanity_val_steps=0 + ) + trainer = Trainer(**trainer_options) + trainer.logger.log_metrics = _log_metrics_decorator(trainer.logger.log_metrics) + trainer.fit(model)