diff --git a/pytorch_lightning/trainer/connectors/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector.py
index dd58dfc684..4eb89fad40 100644
--- a/pytorch_lightning/trainer/connectors/logger_connector.py
+++ b/pytorch_lightning/trainer/connectors/logger_connector.py
@@ -87,7 +87,7 @@ class LoggerConnector:
             self.trainer.logger.save()
 
         # track the logged metrics
-        self.logged_metrics = scalar_metrics
+        self.logged_metrics.update(scalar_metrics)
         self.trainer.dev_debugger.track_logged_metrics_history(scalar_metrics)
 
     def add_progress_bar_metrics(self, metrics):
@@ -191,9 +191,8 @@ class LoggerConnector:
 
         return eval_loop_results
 
-    def on_train_epoch_end(self, epoch_output, checkpoint_accumulator, early_stopping_accumulator, num_optimizers):
-        self.log_train_epoch_end_metrics(epoch_output, checkpoint_accumulator,
-                                         early_stopping_accumulator, num_optimizers)
+    def on_train_epoch_end(self, epoch_output):
+        pass
 
     def log_train_epoch_end_metrics(self,
                                     epoch_output,
@@ -413,7 +412,7 @@ class LoggerConnector:
 
         return gathered_epoch_outputs
 
-    def save_train_loop_metrics_to_loggers(self, batch_idx, batch_output):
+    def log_train_step_metrics(self, batch_idx, batch_output):
         # when metrics should be logged
         should_log_metrics = (batch_idx + 1) % self.trainer.row_log_interval == 0 or self.trainer.should_stop
         if should_log_metrics or self.trainer.fast_dev_run:
diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py
index 4a4f8984da..088ed07d57 100644
--- a/pytorch_lightning/trainer/evaluation_loop.py
+++ b/pytorch_lightning/trainer/evaluation_loop.py
@@ -245,6 +245,11 @@ class EvaluationLoop(object):
         return eval_results
 
     def on_evaluation_batch_start(self, *args, **kwargs):
+        # reset the result of the PL module
+        model = self.trainer.get_model()
+        model._results = Result()
+        model._current_fx_name = 'evaluation_step'
+
         if self.testing:
             self.trainer.call_hook('on_test_batch_start', *args, **kwargs)
         else:
@@ -273,21 +278,28 @@ class EvaluationLoop(object):
         else:
             self.trainer.call_hook('on_validation_epoch_end', *args, **kwargs)
 
-    def log_step_metrics(self, output, batch_idx):
+    def log_evaluation_step_metrics(self, output, batch_idx):
         if self.trainer.running_sanity_check:
             return
 
+        results = self.trainer.get_model()._results
+        self.__log_result_step_metrics(results, batch_idx)
+
+        # TODO: deprecate at 1.0
         if isinstance(output, EvalResult):
-            step_log_metrics = output.batch_log_metrics
-            step_pbar_metrics = output.batch_pbar_metrics
+            self.__log_result_step_metrics(output, batch_idx)
 
-            if len(step_log_metrics) > 0:
-                # make the metrics appear as a different line in the same graph
-                metrics_by_epoch = {}
-                for k, v in step_log_metrics.items():
-                    metrics_by_epoch[f'{k}/epoch_{self.trainer.current_epoch}'] = v
+    def __log_result_step_metrics(self, output, batch_idx):
+        step_log_metrics = output.batch_log_metrics
+        step_pbar_metrics = output.batch_pbar_metrics
 
-                self.trainer.logger_connector.log_metrics(metrics_by_epoch, {}, step=batch_idx)
+        if len(step_log_metrics) > 0:
+            # make the metrics appear as a different line in the same graph
+            metrics_by_epoch = {}
+            for k, v in step_log_metrics.items():
+                metrics_by_epoch[f'{k}/epoch_{self.trainer.current_epoch}'] = v
 
-        if len(step_pbar_metrics) > 0:
-            self.trainer.logger_connector.add_progress_bar_metrics(step_pbar_metrics)
+            self.trainer.logger_connector.log_metrics(metrics_by_epoch, {}, step=batch_idx)
+
+        if len(step_pbar_metrics) > 0:
+            self.trainer.logger_connector.add_progress_bar_metrics(step_pbar_metrics)
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index fa96550e7b..b292335ca2 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -436,7 +436,7 @@ class Trainer(
         # clean up
         self.evaluation_loop.evaluation_batch_end_cleanup(output, batch_idx, dataloader_idx)
 
-        self.evaluation_loop.log_step_metrics(output, batch_idx)
+        self.evaluation_loop.log_evaluation_step_metrics(output, batch_idx)
 
         # track epoch level metrics
         if output is not None:
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 3efb51953d..797c4ab85c 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -542,7 +542,7 @@ class TrainLoop:
             # -----------------------------------------
             # SAVE METRICS TO LOGGERS
             # -----------------------------------------
-            self.trainer.logger_connector.save_train_loop_metrics_to_loggers(batch_idx, batch_output)
+            self.trainer.logger_connector.log_train_step_metrics(batch_idx, batch_output)
 
             # -----------------------------------------
             # SAVE LOGGERS (ie: Tensorboard, etc...)
@@ -573,14 +573,17 @@ class TrainLoop:
             # progress global step according to grads progress
            self.increment_accumulated_grad_global_step()
 
-        # process epoch outputs
-        self.trainer.logger_connector.on_train_epoch_end(
+        # log epoch metrics
+        self.trainer.logger_connector.log_train_epoch_end_metrics(
             epoch_output,
             self.checkpoint_accumulator,
             self.early_stopping_accumulator,
             self.num_optimizers
         )
 
+        # hook
+        self.trainer.logger_connector.on_train_epoch_end(epoch_output)
+
         # when no val loop is present or fast-dev-run still need to call checkpoints
         self.check_checkpoint_callback(not (should_check_val or is_overridden('validation_step', model)))
 
@@ -704,6 +707,7 @@ class TrainLoop:
         batch_log_metrics = {k: v for d in batch_log_metrics for k, v in d.items()}
 
         # track all metrics for callbacks
+        # TODO: is this needed?
         self.trainer.logger_connector.callback_metrics.update(
             {k: v for d in batch_callback_metrics for k, v in d.items() if v is not None}
         )
diff --git a/tests/trainer/test_trainining_step_no_dict_result.py b/tests/trainer/test_trainining_step_no_dict_result.py
index 59b5835fd0..1580e93974 100644
--- a/tests/trainer/test_trainining_step_no_dict_result.py
+++ b/tests/trainer/test_trainining_step_no_dict_result.py
@@ -215,3 +215,59 @@ def test_training_step_dict(tmpdir):
     # epoch 1
     assert trainer.dev_debugger.logged_metrics[3]['global_step'] == 2
     assert trainer.dev_debugger.logged_metrics[4]['global_step'] == 3
+
+
+def test_validation_step_logging(tmpdir):
+    """
+    Tests that validation_step can log step and epoch metrics via self.log
+    """
+    os.environ['PL_DEV_DEBUG'] = '1'
+
+    class TestModel(DeterministicModel):
+        def training_step(self, batch, batch_idx):
+            acc = self.step(batch, batch_idx)
+            acc = acc + batch_idx
+            self.log('train_step_acc', acc, on_step=True, on_epoch=True)
+            self.training_step_called = True
+            return acc
+
+        def validation_step(self, batch, batch_idx):
+            acc = self.step(batch, batch_idx)
+            acc = acc + batch_idx
+            self.log('val_step_acc', acc, on_step=True, on_epoch=True)
+            self.validation_step_called = True
+
+        def backward(self, trainer, loss, optimizer, optimizer_idx):
+            loss.backward()
+
+    model = TestModel()
+    model.validation_step_end = None
+    model.validation_epoch_end = None
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        limit_train_batches=2,
+        limit_val_batches=2,
+        max_epochs=2,
+        row_log_interval=1,
+        weights_summary=None,
+    )
+    trainer.fit(model)
+
+    # make sure all the expected metrics reached the loggers
+    expected_logged_metrics = {
+        'epoch',
+        'train_step_acc', 'step_train_step_acc', 'epoch_train_step_acc',
+        'val_step_acc/epoch_0', 'val_step_acc/epoch_1',
+        'step_val_step_acc/epoch_0', 'step_val_step_acc/epoch_1',
+    }
+    logged_metrics = set(trainer.logged_metrics.keys())
+    assert expected_logged_metrics == logged_metrics
+
+    # we don't want to enable val metrics during steps because it is not something that users should do
+    expected_cb_metrics = [
+        'train_step_acc', 'step_train_step_acc', 'epoch_train_step_acc',
+    ]
+    expected_cb_metrics = set(expected_cb_metrics)
+    callback_metrics = set(trainer.callback_metrics.keys())
+    assert expected_cb_metrics == callback_metrics
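
Usage sketch (not part of the patch above): after this change, self.log(...) called inside validation_step is captured in the module's Result object and logged per step under keys suffixed with the current epoch (for example val_loss/epoch_0), while being kept out of trainer.callback_metrics. The toy module below is illustrative only; it assumes the APIs that appear in the diff and its test (self.log, row_log_interval, trainer.logged_metrics, trainer.callback_metrics) and is not taken from the patch itself.

# illustrative sketch, assuming the self.log / row_log_interval APIs shown in the diff
import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class ToyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        loss = self.layer(batch[0]).sum()
        # logged every row_log_interval steps and aggregated at epoch end
        self.log('train_loss', loss, on_step=True, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self.layer(batch[0]).sum()
        # per-step value shows up in loggers as 'val_loss/epoch_<current_epoch>'
        self.log('val_loss', loss, on_step=True, on_epoch=True)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


dataset = TensorDataset(torch.randn(8, 32))
train_loader = DataLoader(dataset, batch_size=4)
val_loader = DataLoader(dataset, batch_size=4)

trainer = pl.Trainer(max_epochs=1, limit_train_batches=2, limit_val_batches=2, row_log_interval=1)
trainer.fit(ToyModel(), train_loader, val_loader)

# validation step metrics are logged per epoch but not exposed to callbacks
print(sorted(trainer.logged_metrics.keys()))
print(sorted(trainer.callback_metrics.keys()))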