From b8371fa56c0c394722793a7cd2f1cfea5f2a74ca Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Sat, 15 Aug 2020 08:36:00 -0400
Subject: [PATCH] Fixes #2972 #2946 (#2986)

* add val step arg to metrics

* add val step arg to metrics

* add val step arg to metrics

* add val step arg to metrics

* add val step arg to metrics

* add val step arg to metrics

* add val step arg to metrics

* add val step arg to metrics

* add val step arg to metrics

* add step metrics

* add step metrics
---
 pytorch_lightning/trainer/evaluation_loop.py  | 13 +++++++++----
 pytorch_lightning/trainer/logging.py          |  4 +++-
 .../test_validation_steps_result_return.py    | 18 ++++++++++++------
 3 files changed, 24 insertions(+), 11 deletions(-)

diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py
index a6451d3610..3b07b81dae 100644
--- a/pytorch_lightning/trainer/evaluation_loop.py
+++ b/pytorch_lightning/trainer/evaluation_loop.py
@@ -217,7 +217,7 @@ class TrainerEvaluationLoopMixin(ABC):
         """Warning: this is just empty shell for code implemented in other class."""
 
     @abstractmethod
-    def log_metrics(self, *args):
+    def log_metrics(self, *args, **kwargs):
         """Warning: this is just empty shell for code implemented in other class."""
 
     @abstractmethod
@@ -379,7 +379,7 @@ class TrainerEvaluationLoopMixin(ABC):
 
                 dl_outputs.append(output)
 
-                self.__eval_add_step_metrics(output)
+                self.__eval_add_step_metrics(output, batch_idx)
 
                 # track debug metrics
                 self.dev_debugger.track_eval_loss_history(test_mode, batch_idx, dataloader_idx, output)
@@ -505,14 +505,19 @@ class TrainerEvaluationLoopMixin(ABC):
             eval_results = eval_results[0]
         return eval_results
 
-    def __eval_add_step_metrics(self, output):
+    def __eval_add_step_metrics(self, output, batch_idx):
         # track step level metrics
         if isinstance(output, EvalResult) and not self.running_sanity_check:
             step_log_metrics = output.batch_log_metrics
             step_pbar_metrics = output.batch_pbar_metrics
 
             if len(step_log_metrics) > 0:
-                self.log_metrics(step_log_metrics, {})
+                # make the metrics appear as a different line in the same graph
+                metrics_by_epoch = {}
+                for k, v in step_log_metrics.items():
+                    metrics_by_epoch[f'{k}/epoch_{self.current_epoch}'] = v
+
+                self.log_metrics(metrics_by_epoch, {}, step=batch_idx)
 
             if len(step_pbar_metrics) > 0:
                 self.add_progress_bar_metrics(step_pbar_metrics)
diff --git a/pytorch_lightning/trainer/logging.py b/pytorch_lightning/trainer/logging.py
index c90ba59abf..f84b071114 100644
--- a/pytorch_lightning/trainer/logging.py
+++ b/pytorch_lightning/trainer/logging.py
@@ -64,10 +64,12 @@ class TrainerLoggingMixin(ABC):
 
         if "step" in scalar_metrics and step is None:
             step = scalar_metrics.pop("step")
-        else:
+
+        elif step is None:
             # added metrics by Lightning for convenience
             scalar_metrics['epoch'] = self.current_epoch
             step = step if step is not None else self.global_step
+
         # log actual metrics
         if self.is_global_zero and self.logger is not None:
             self.logger.agg_and_log_metrics(scalar_metrics, step=step)
diff --git a/tests/trainer/test_validation_steps_result_return.py b/tests/trainer/test_validation_steps_result_return.py
index 8162f57287..28f012535d 100644
--- a/tests/trainer/test_validation_steps_result_return.py
+++ b/tests/trainer/test_validation_steps_result_return.py
@@ -214,12 +214,15 @@ def test_val_step_only_step_metrics(tmpdir):
 
     # make sure we logged the correct epoch metrics
     total_empty_epoch_metrics = 0
+    epoch = 0
     for metric in trainer.dev_debugger.logged_metrics:
+        if 'epoch' in metric:
+            epoch += 1
         if len(metric) > 2:
             assert 'no_val_no_pbar' not in metric
             assert 'val_step_pbar_acc' not in metric
-            assert metric['val_step_log_acc']
-            assert metric['val_step_log_pbar_acc']
+            assert metric[f'val_step_log_acc/epoch_{epoch}']
+            assert metric[f'val_step_log_pbar_acc/epoch_{epoch}']
         else:
             total_empty_epoch_metrics += 1
 
@@ -228,6 +231,8 @@ def test_val_step_only_step_metrics(tmpdir):
     # make sure we logged the correct epoch pbar metrics
     total_empty_epoch_metrics = 0
     for metric in trainer.dev_debugger.pbar_added_metrics:
+        if 'epoch' in metric:
+            epoch += 1
         if len(metric) > 2:
             assert 'no_val_no_pbar' not in metric
             assert 'val_step_log_acc' not in metric
@@ -288,11 +293,12 @@ def test_val_step_epoch_step_metrics(tmpdir):
     for metric_idx in range(0, len(trainer.dev_debugger.logged_metrics), batches + 1):
         batch_metrics = trainer.dev_debugger.logged_metrics[metric_idx: metric_idx + batches]
         epoch_metric = trainer.dev_debugger.logged_metrics[metric_idx + batches]
+        epoch = epoch_metric['epoch']
 
         # make sure the metric was split
         for batch_metric in batch_metrics:
-            assert 'step_val_step_log_acc' in batch_metric
-            assert 'step_val_step_log_pbar_acc' in batch_metric
+            assert f'step_val_step_log_acc/epoch_{epoch}' in batch_metric
+            assert f'step_val_step_log_pbar_acc/epoch_{epoch}' in batch_metric
 
         # make sure the epoch split was correct
         assert 'epoch_val_step_log_acc' in epoch_metric
@@ -421,11 +427,11 @@ def test_val_step_full_loop_result_dp(tmpdir):
     assert 'train_step_metric' in seen_keys
     assert 'train_step_end_metric' in seen_keys
     assert 'epoch_train_epoch_end_metric' in seen_keys
-    assert 'step_validation_step_metric' in seen_keys
+    assert 'step_validation_step_metric/epoch_0' in seen_keys
    assert 'epoch_validation_step_metric' in seen_keys
     assert 'validation_step_end_metric' in seen_keys
     assert 'validation_epoch_end_metric' in seen_keys
-    assert 'step_test_step_metric' in seen_keys
+    assert 'step_test_step_metric/epoch_2' in seen_keys
     assert 'epoch_test_step_metric' in seen_keys
     assert 'test_step_end_metric' in seen_keys
     assert 'test_epoch_end_metric' in seen_keys
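
The net effect of the patch: step-level validation/test metrics are re-keyed with an
"/epoch_{current_epoch}" suffix and logged with step=batch_idx, so each epoch appears as
its own curve on the same chart, which is also why the test assertions now look up keys
like 'val_step_log_acc/epoch_0'. Below is a minimal standalone sketch of that re-keying
in isolation; rekey_step_metrics and the driver are illustrative stand-ins, not Lightning API.

# Illustrative sketch only (not part of the patch): mirrors the re-keying the patched
# __eval_add_step_metrics performs before calling log_metrics.

def rekey_step_metrics(step_log_metrics, current_epoch):
    # 'val_step_log_acc' -> 'val_step_log_acc/epoch_1' for epoch 1, so every
    # epoch gets its own line in the same logger chart.
    return {f'{k}/epoch_{current_epoch}': v for k, v in step_log_metrics.items()}


if __name__ == '__main__':
    # Hypothetical step metrics from a validation_step EvalResult, epoch 1, batch 7.
    step_log_metrics = {'val_step_log_acc': 0.91, 'val_step_log_pbar_acc': 0.88}
    metrics_by_epoch = rekey_step_metrics(step_log_metrics, current_epoch=1)

    # In the patch this becomes: self.log_metrics(metrics_by_epoch, {}, step=batch_idx)
    print(metrics_by_epoch, '-> logged at step (batch_idx) =', 7)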