From f42ea303c9e1beabb6db4c341a44e966458261d3 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Tue, 29 Sep 2020 02:00:28 -0400
Subject: [PATCH] ref: enable self.log for eval loop metrics (#3715)

* ref: test val epoch end

* ref: test val epoch end

* ref: test val epoch end

* ref: test val epoch end

* ref: test val epoch end

* ref: test val epoch end
---
 pytorch_lightning/core/lightning.py           |  36 +++++-
 pytorch_lightning/core/step_result.py         |  13 ++-
 .../trainer/connectors/logger_connector.py    |  83 +++++++++++---
 pytorch_lightning/trainer/evaluation_loop.py  |  18 ++-
 pytorch_lightning/trainer/trainer.py          |  14 ++-
 tests/trainer/test_eval_loop_logging_1_0.py   | 103 ++++++++++++++++--
 6 files changed, 227 insertions(+), 40 deletions(-)

diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index d95210a52c..8966c3677e 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -155,8 +155,8 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
         value: Any,
         prog_bar: bool = False,
         logger: bool = True,
-        on_step: bool = True,
-        on_epoch: bool = True,
+        on_step: Union[None, bool] = None,
+        on_epoch: Union[None, bool] = None,
         reduce_fx: Callable = torch.mean,
         tbptt_reduce_fx: Callable = torch.mean,
         tbptt_pad_token: int = 0,
@@ -190,8 +190,8 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
             value: value name
             prog_bar: if True logs to the progress base
             logger: if True logs to the logger
-            on_step: if True logs the output of validation_step or test_step
-            on_epoch: if True, logs the output of the training loop aggregated
+            on_step: if True logs at this step. None auto-logs for training_step but not validation/test_step
+            on_epoch: if True logs epoch accumulated metrics. None auto-logs for val/test step but not training_step
             reduce_fx: Torch.mean by default
             tbptt_reduce_fx: function to reduce on truncated back prop
             tbptt_pad_token: token to use for padding
@@ -205,6 +205,10 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
         if 'epoch_end' in self._current_fx_name and on_step:
             on_step = False
 
+        # set the default depending on the fx_name
+        on_step = self.__auto_choose_log_on_step(on_step)
+        on_epoch = self.__auto_choose_log_on_epoch(on_epoch)
+
         self._results.log(
             name,
             value,
@@ -221,6 +225,30 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
             sync_dist_group
         )
 
+    def __auto_choose_log_on_step(self, on_step):
+        if on_step is None:
+            if self._current_fx_name in {'training_step', 'training_step_end'}:
+                on_step = True
+            elif self._current_fx_name in {'evaluation_step', 'evaluation_step_end',
+                                           'evaluation_epoch_end', 'training_epoch_end'}:
+                on_step = False
+            else:
+                on_step = False
+
+        return on_step
+
+    def __auto_choose_log_on_epoch(self, on_epoch):
+        if on_epoch is None:
+            if self._current_fx_name in {'training_step', 'training_step_end'}:
+                on_epoch = False
+            elif self._current_fx_name in {'evaluation_step', 'evaluation_step_end',
+                                           'evaluation_epoch_end', 'training_epoch_end'}:
+                on_epoch = True
+            else:
+                on_epoch = True
+
+        return on_epoch
+
     def forward(self, *args, **kwargs):
         r"""
         Same as :meth:`torch.nn.Module.forward()`, however in Lightning you want this to define
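A minimal sketch, not part of the patch, of how the new `None` defaults resolve per hook. `resolve_log_flags` is an illustrative name rather than a Lightning API; it mirrors the `__auto_choose_log_on_step` / `__auto_choose_log_on_epoch` helpers above, which only fill in `None` and leave explicit `on_step=` / `on_epoch=` arguments untouched:

```python
# Illustrative only: mirrors the auto-default logic added in lightning.py above.
def resolve_log_flags(fx_name, on_step=None, on_epoch=None):
    """Return (on_step, on_epoch) after applying the per-hook defaults."""
    step_hooks = {'training_step', 'training_step_end'}
    if on_step is None:
        # only the training step hooks default to per-step logging
        on_step = fx_name in step_hooks
    if on_epoch is None:
        # every other hook defaults to epoch-level aggregation
        on_epoch = fx_name not in step_hooks
    return on_step, on_epoch

assert resolve_log_flags('training_step') == (True, False)
assert resolve_log_flags('evaluation_step') == (False, True)
assert resolve_log_flags('training_epoch_end') == (False, True)
assert resolve_log_flags('evaluation_step', on_step=True) == (True, True)
```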
diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py
index b527316ab6..7ef23aaee4 100644
--- a/pytorch_lightning/core/step_result.py
+++ b/pytorch_lightning/core/step_result.py
@@ -352,9 +352,14 @@ class Result(Dict):
     @classmethod
     def reduce_on_epoch_end(cls, outputs):
         # get the batch sizes for all outputs
-        batch_sizes = torch.stack([x.get_batch_sizes() for x in outputs]).view(-1)
+        batch_sizes = []
+        meta = {}
+        for x in outputs:
+            batch_sizes.append(x.get_batch_sizes())
+            meta.update(x['meta'])
+
+        batch_sizes = torch.stack(batch_sizes).view(-1)
 
-        meta = outputs[0]['meta']
         result = cls()
         result = recursive_gather(outputs, result)
         recursive_stack(result)
@@ -371,6 +376,8 @@ class Result(Dict):
                 reduced_val = fx(result[k])
 
                 result[k] = reduced_val
+            else:
+                del result[k]
 
         result['meta'] = meta
         return result
@@ -871,7 +878,7 @@ class EvalResult(Result):
 
 
 def weighted_mean(result, weights):
-    weights = weights.to(result.device)
+    weights = weights.to(result.device)[:result.size(0)]
     numerator = torch.dot(result.float(), weights.transpose(-1, 0).float())
     result = numerator / weights.sum().float()
     return result
diff --git a/pytorch_lightning/trainer/connectors/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector.py
index 4eb89fad40..8f74d2d723 100644
--- a/pytorch_lightning/trainer/connectors/logger_connector.py
+++ b/pytorch_lightning/trainer/connectors/logger_connector.py
@@ -20,6 +20,7 @@ from pytorch_lightning.core.step_result import EvalResult, Result
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pprint import pprint
 from typing import Iterable
+from copy import deepcopy
 
 
 class LoggerConnector:
@@ -29,6 +30,7 @@
         self.callback_metrics = {}
         self.logged_metrics = {}
         self.progress_bar_metrics = {}
+        self.eval_loop_results = []
 
     def on_trainer_init(self, logger, log_save_interval, row_log_interval):
         # logging
@@ -100,11 +102,69 @@ class LoggerConnector:
             self.trainer.dev_debugger.track_pbar_metrics_history(metrics)
 
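The `weighted_mean` change in step_result.py above guards against `weights` (one batch size per step) being longer than `result` (one value per step the key was actually logged), which can happen now that keys no longer need to appear in every step. A small runnable sketch of the fixed behavior under that assumption; the 1-D `transpose` in the original is a no-op and is dropped here for brevity:

```python
import torch

# Illustrative only: a batch-size-weighted mean where `weights` may be longer
# than `result`, mirroring the weighted_mean fix above.
def weighted_mean(result: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
    # truncate the weights so a key logged in fewer steps than there were
    # batches still reduces cleanly
    weights = weights.to(result.device)[:result.size(0)]
    return torch.dot(result.float(), weights.float()) / weights.sum().float()

vals = torch.tensor([1.0, 3.0])          # metric logged for 2 steps
batch_sizes = torch.tensor([4, 4, 2])    # 3 batches were seen overall
print(weighted_mean(vals, batch_sizes))  # tensor(2.) -> (1*4 + 3*4) / 8
```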
     def on_evaluation_epoch_end(self, eval_results, using_eval_result, test_mode):
-        # TODO: merge both functions?
-        self._log_on_evaluation_epoch_end_metrics(eval_results, using_eval_result)
-        return self.__log_evaluation_epoch_metrics_2(eval_results, test_mode)
+        self._track_callback_metrics(eval_results, using_eval_result)
+        self._log_on_evaluation_epoch_end_metrics()
 
-    def _log_on_evaluation_epoch_end_metrics(self, eval_results, using_eval_result):
+        # TODO: deprecate parts of this for 1.0 (when removing results)
+        self.__process_eval_epoch_end_results_and_log_legacy(eval_results, test_mode)
+
+        # get the final loop results
+        eval_loop_results = self._get_evaluate_epoch_results(test_mode)
+        return eval_loop_results
+
+    def _get_evaluate_epoch_results(self, test_mode):
+        # log results of test
+        if test_mode and self.trainer.is_global_zero and self.trainer.verbose_test:
+            print('-' * 80)
+            for result_idx, results in enumerate(self.eval_loop_results):
+                print(f'DATALOADER:{result_idx} TEST RESULTS')
+                pprint(results)
+                print('-' * 80)
+
+        results = self.eval_loop_results
+
+        # clear mem
+        self.eval_loop_results = []
+        return results
+
+    def _log_on_evaluation_epoch_end_metrics(self):
+        step_metrics = self.trainer.evaluation_loop.step_metrics
+
+        # clear mem
+        self.trainer.evaluation_loop.step_metrics = []
+
+        num_loaders = len(step_metrics)
+
+        # process metrics per dataloader
+        for dl_idx, dl_metrics in enumerate(step_metrics):
+            if len(dl_metrics) == 0:
+                continue
+
+            reduced_epoch_metrics = dl_metrics[0].__class__.reduce_on_epoch_end(dl_metrics)
+            # make the keys 'k/dl'
+            reduced_epoch_metrics = self.__rename_keys_by_dataloader_idx(reduced_epoch_metrics, dl_idx, num_loaders)
+
+            # track the metrics
+            logger_metrics = reduced_epoch_metrics.get_epoch_log_metrics()
+            pbar_metrics = reduced_epoch_metrics.get_epoch_pbar_metrics()
+            self.logged_metrics.update(logger_metrics)
+            self.progress_bar_metrics.update(pbar_metrics)
+
+            # enable the metrics to be monitored
+            self.callback_metrics.update(logger_metrics)
+            self.callback_metrics.update(pbar_metrics)
+
+            # track the final results for the dataloader
+            self.eval_loop_results.append(deepcopy(self.callback_metrics))
+
+    def __rename_keys_by_dataloader_idx(self, metrics, dataloader_idx, num_loaders):
+        if num_loaders == 1:
+            return metrics
+
+        result = {f'{k}/dataloader_idx_{dataloader_idx}': v for k, v in metrics.items()}
+        return result
+
+    def _track_callback_metrics(self, eval_results, using_eval_result):
         if len(eval_results) > 0 and eval_results[0] is None:
             return
 
@@ -141,11 +201,10 @@ class LoggerConnector:
                 flat['early_stop_on'] = flat['val_loss']
             self.trainer.logger_connector.callback_metrics.update(flat)
 
-    def __log_evaluation_epoch_metrics_2(self, eval_results, test_mode):
+    def __process_eval_epoch_end_results_and_log_legacy(self, eval_results, test_mode):
         if self.trainer.running_sanity_check:
             return
 
-        eval_loop_results = []
         if eval_results is not None and len(eval_results) > 0:
 
             # in eval, the user may return something at every validation step without final reduction
@@ -179,17 +238,7 @@ class LoggerConnector:
                 self.trainer.logger_connector.callback_metrics.update(prog_bar_metrics)
 
             if len(dataloader_result_metrics) > 0:
-                eval_loop_results.append(dataloader_result_metrics)
-
-        # log results of test
-        if test_mode and self.trainer.is_global_zero and self.trainer.verbose_test:
-            print('-' * 80)
-            for result_idx, results in enumerate(eval_loop_results):
-                print(f'DATALOADER:{result_idx} TEST RESULTS')
-                pprint(results)
-                print('-' * 80)
-
-        return eval_loop_results
+                self.eval_loop_results.append(dataloader_result_metrics)
 
     def on_train_epoch_end(self, epoch_output):
         pass
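With multiple val/test dataloaders, the connector above suffixes each key so metrics from different loaders stay distinguishable. A standalone sketch, not part of the patch, mirroring `__rename_keys_by_dataloader_idx`:

```python
# Illustrative only: the per-dataloader key suffixing used by the connector.
def rename_keys_by_dataloader_idx(metrics, dataloader_idx, num_loaders):
    if num_loaders == 1:
        return metrics  # single loader: keys stay untouched
    return {f'{k}/dataloader_idx_{dataloader_idx}': v for k, v in metrics.items()}

print(rename_keys_by_dataloader_idx({'val_acc': 0.9}, 0, 2))
# {'val_acc/dataloader_idx_0': 0.9}
```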
diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py
index 088ed07d57..8053a2eea1 100644
--- a/pytorch_lightning/trainer/evaluation_loop.py
+++ b/pytorch_lightning/trainer/evaluation_loop.py
@@ -23,6 +23,7 @@ class EvaluationLoop(object):
         self.trainer = trainer
         self.testing = False
         self.outputs = []
+        self.step_metrics = []
         self.predictions = None
         self.max_batches = None
 
@@ -278,14 +279,21 @@ class EvaluationLoop(object):
         else:
             self.trainer.call_hook('on_validation_epoch_end', *args, **kwargs)
 
-    def log_evaluation_step_metrics(self, output, batch_idx):
+    def log_evaluation_step_metrics(self, batch, batch_idx):
+        results = self.trainer.get_model()._results
+        if len(results) == 1:
+            return None
+
+        results.track_batch_size(len(batch))
+        self.__log_result_step_metrics(results, batch_idx)
+
+        return results
+
+    # TODO: deprecate at 1.0
+    def log_evaluation_step_metrics_legacy(self, output, batch_idx):
         if self.trainer.running_sanity_check:
             return
 
-        results = self.trainer.get_model()._results
-        self.__log_result_step_metrics(results, batch_idx)
-
-        # TODO: deprecate at 1.0
         if isinstance(output, EvalResult):
             self.__log_result_step_metrics(output, batch_idx)
 
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index b292335ca2..e5b04ab310 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -413,6 +413,7 @@ class Trainer(
         for dataloader_idx, dataloader in enumerate(dataloaders):
             # bookkeeping
             dl_outputs = []
+            dl_step_metrics = []
             dataloader = self.accelerator_backend.process_dataloader(dataloader)
             dl_max_batches = self.evaluation_loop.max_batches[dataloader_idx]
 
@@ -436,13 +437,22 @@ class Trainer(
 
                 # clean up
                 self.evaluation_loop.evaluation_batch_end_cleanup(output, batch_idx, dataloader_idx)
-                self.evaluation_loop.log_evaluation_step_metrics(output, batch_idx)
 
-                # track epoch level metrics
+                # TODO: deprecate 1.0
+                self.evaluation_loop.log_evaluation_step_metrics_legacy(output, batch_idx)
+
+                # log step metrics
+                step_metrics = self.evaluation_loop.log_evaluation_step_metrics(batch, batch_idx)
+
+                if step_metrics is not None:
+                    dl_step_metrics.append(step_metrics)
+
+                # track epoch level outputs
                 if output is not None:
                     dl_outputs.append(output)
 
             self.evaluation_loop.outputs.append(dl_outputs)
+            self.evaluation_loop.step_metrics.append(dl_step_metrics)
 
         # lightning module method
         eval_results = self.evaluation_loop.evaluation_epoch_end(num_dataloaders=len(dataloaders))
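Taken together, the two diffs above make the trainer collect one list of step results per dataloader and hand them to the logger connector for the epoch-end reduction. A rough, runnable sketch of that bookkeeping shape with dummy data; `run_batch` is a hypothetical stand-in for `log_evaluation_step_metrics`, not a Lightning API:

```python
# Illustrative only: the nested-list bookkeeping the trainer loop performs.
def run_batch(batch, batch_idx):
    # stand-in for log_evaluation_step_metrics: returns logged results or None
    return {'val_loss': float(sum(batch)), 'batch_size': len(batch)}

dataloaders = [[[1, 2], [3, 4]], [[5, 6]]]   # two dummy dataloaders
step_metrics = []                            # ~ evaluation_loop.step_metrics
for dataloader in dataloaders:
    dl_step_metrics = []                     # one list per dataloader
    for batch_idx, batch in enumerate(dataloader):
        results = run_batch(batch, batch_idx)
        if results is not None:              # None means nothing was logged
            dl_step_metrics.append(results)
    step_metrics.append(dl_step_metrics)

# the logger connector later reduces each per-dataloader list on epoch end
assert len(step_metrics) == 2 and len(step_metrics[0]) == 2
```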
diff --git a/tests/trainer/test_eval_loop_logging_1_0.py b/tests/trainer/test_eval_loop_logging_1_0.py
index ff544c21d5..42499dda10 100644
--- a/tests/trainer/test_eval_loop_logging_1_0.py
+++ b/tests/trainer/test_eval_loop_logging_1_0.py
@@ -17,14 +17,14 @@ def test__validation_step__log(tmpdir):
         def training_step(self, batch, batch_idx):
             acc = self.step(batch, batch_idx)
             acc = acc + batch_idx
-            self.log('train_step_acc', acc, on_step=True, on_epoch=True)
+            self.log('a', acc, on_step=True, on_epoch=True)
             self.training_step_called = True
             return acc
 
         def validation_step(self, batch, batch_idx):
             acc = self.step(batch, batch_idx)
             acc = acc + batch_idx
-            self.log('val_step_acc', acc, on_step=True, on_epoch=True)
+            self.log('b', acc, on_step=True, on_epoch=True)
             self.training_step_called = True
 
         def backward(self, trainer, loss, optimizer, optimizer_idx):
@@ -46,18 +46,103 @@ def test__validation_step__log(tmpdir):
 
     # make sure all the metrics are available for callbacks
     expected_logged_metrics = {
+        'a',
+        'step_a',
+        'epoch_a',
+        'b',
+        'step_b/epoch_0',
+        'step_b/epoch_1',
+        'b/epoch_0',
+        'b/epoch_1',
+        'epoch_b',
         'epoch',
-        'train_step_acc', 'step_train_step_acc', 'epoch_train_step_acc',
-        'val_step_acc/epoch_0', 'val_step_acc/epoch_1',
-        'step_val_step_acc/epoch_0', 'step_val_step_acc/epoch_1',
     }
     logged_metrics = set(trainer.logged_metrics.keys())
     assert expected_logged_metrics == logged_metrics
 
     # we don't want to enable val metrics during steps because it is not something that users should do
-    expected_cb_metrics = [
-        'train_step_acc', 'step_train_step_acc', 'epoch_train_step_acc',
-    ]
-    expected_cb_metrics = set(expected_cb_metrics)
+    # on purpose DO NOT allow step_b... it's silly to monitor val step metrics
+    expected_cb_metrics = {'a', 'b', 'epoch_a', 'epoch_b', 'step_a'}
     callback_metrics = set(trainer.callback_metrics.keys())
     assert expected_cb_metrics == callback_metrics
+
+
+def test__validation_step__epoch_end__log(tmpdir):
+    """
+    Tests that validation_step can log
+    """
+    os.environ['PL_DEV_DEBUG'] = '1'
+
+    class TestModel(DeterministicModel):
+        def training_step(self, batch, batch_idx):
+            acc = self.step(batch, batch_idx)
+            acc = acc + batch_idx
+            self.log('a', acc)
+            self.log('b', acc, on_step=True, on_epoch=True)
+            self.training_step_called = True
+            return acc
+
+        def validation_step(self, batch, batch_idx):
+            acc = self.step(batch, batch_idx)
+            acc = acc + batch_idx
+            self.log('c', acc)
+            self.log('d', acc, on_step=True, on_epoch=True)
+            self.validation_step_called = True
+
+        def validation_epoch_end(self, outputs):
+            self.log('e', torch.tensor(2, device=self.device), on_step=True, on_epoch=True)
+            self.validation_epoch_end_called = True
+
+        def backward(self, trainer, loss, optimizer, optimizer_idx):
+            loss.backward()
+
+    model = TestModel()
+    model.validation_step_end = None
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        limit_train_batches=2,
+        limit_val_batches=2,
+        max_epochs=2,
+        row_log_interval=1,
+        weights_summary=None,
+    )
+    trainer.fit(model)
+
+    # make sure all the metrics are available for callbacks
+    expected_logged_metrics = {
+        'a',
+        'b',
+        'step_b',
+        'epoch_b',
+        'c',
+        'd',
+        'd/epoch_0',
+        'd/epoch_1',
+        'step_d/epoch_0',
+        'step_d/epoch_1',
+        'epoch_d',
+        'e',
+        'epoch_e',
+        'epoch',
+    }
+
+    logged_metrics = set(trainer.logged_metrics.keys())
+    assert expected_logged_metrics == logged_metrics
+
+    # we don't want to enable val metrics during steps because it is not something that users should do
+    expected_cb_metrics = {
+        'a',
+        'b',
+        'step_b',
+        'epoch_b',
+        'c',
+        'd',
+        'epoch_d',
+        'e',
+        'epoch_e',
+        'debug_epoch',
+    }
+    callback_metrics = set(trainer.callback_metrics.keys())
+    assert expected_cb_metrics == callback_metrics
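For reference, a sketch of the user-facing pattern these tests exercise, assuming a minimal module and a PyTorch Lightning install of this era; the key sets asserted above ('d', 'step_d/epoch_0', 'epoch_d', ...) are what such calls produce under the new defaults:

```python
import torch
import pytorch_lightning as pl

class LitModel(pl.LightningModule):  # hypothetical minimal module
    def validation_step(self, batch, batch_idx):
        acc = torch.tensor(0.9)
        self.log('c', acc)                               # epoch-aggregated by default
        self.log('d', acc, on_step=True, on_epoch=True)  # logs step and epoch variants

    def validation_epoch_end(self, outputs):
        self.log('e', torch.tensor(2.0))                 # logged once per epoch
```

Note that only the epoch-level variants of validation metrics land in `callback_metrics`; per-step validation keys are deliberately kept out of monitoring, as both tests assert.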