From c62f68c7cd25e644986caa7a8bd4c4ffb066de86 Mon Sep 17 00:00:00 2001 From: Swetha Mandava Date: Tue, 26 Jan 2021 14:01:46 +0100 Subject: [PATCH] passing batch outputs to on_train_batch_end (#4369) * passing batch outputs to on_train_batch_end * styling * updating epoch end logic * also condition on on_train_epoch_end hooks * more readable * pep8 * pep8 * readability suggestion accepted Co-authored-by: Jirka Borovec * adding test_training_epoch_end_metrics_collection_on_override test * fix formatting * fix formatting Co-authored-by: Swetha Mandava Co-authored-by: Jirka Borovec Co-authored-by: Sean Naren Co-authored-by: Roger Shieh (cherry picked from commit 5fcca4e43b243cd9fdb08050b285fb052856f13b) --- pytorch_lightning/trainer/training_loop.py | 43 ++++++++++------- tests/models/test_hooks.py | 55 +++++++++++++++++++++- 2 files changed, 79 insertions(+), 19 deletions(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index fd6c7c8758..436e605aa4 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -241,13 +241,13 @@ class TrainLoop: self.trainer.call_hook("on_epoch_start") self.trainer.call_hook("on_train_epoch_start") - def on_train_batch_end(self, epoch_output, epoch_end_outputs, batch, batch_idx, dataloader_idx): + def on_train_batch_end(self, epoch_output, batch_end_outputs, batch, batch_idx, dataloader_idx): # hook - self.trainer.call_hook('on_train_batch_end', epoch_end_outputs, batch, batch_idx, dataloader_idx) + self.trainer.call_hook('on_train_batch_end', batch_end_outputs, batch, batch_idx, dataloader_idx) self.trainer.call_hook('on_batch_end') # figure out what to track for epoch end - self.track_epoch_end_reduce_metrics(epoch_output, epoch_end_outputs) + self.track_epoch_end_reduce_metrics(epoch_output, batch_end_outputs) # reset batch logger internals self.trainer.logger_connector.on_train_batch_end() @@ -259,12 +259,27 @@ class TrainLoop: if self.trainer.val_dataloaders is None and not self.trainer.reload_dataloaders_every_epoch: self.trainer.reset_val_dataloader(model) - def track_epoch_end_reduce_metrics(self, epoch_output, epoch_end_outputs): + def track_epoch_end_reduce_metrics(self, epoch_output, batch_end_outputs): + # track the outputs to reduce at the end of the epoch - for opt_idx, opt_outputs in enumerate(epoch_end_outputs): + for opt_idx, opt_outputs in enumerate(batch_end_outputs): + sample_output = opt_outputs[-1] + + # decide if we need to reduce at the end of the epoch automatically + auto_reduce_tng_result = isinstance(sample_output, Result) and sample_output.should_reduce_on_epoch_end + hook_overridden = ( + is_overridden("training_epoch_end", model=self.trainer.get_model()) or + is_overridden("on_train_epoch_end", model=self.trainer.get_model()) + ) + + # only track when a) it needs to be autoreduced OR b) the user wants to manually reduce on epoch end + if not(hook_overridden or auto_reduce_tng_result): + continue + # with 1 step (no tbptt) don't use a sequence at epoch end if isinstance(opt_outputs, list) and len(opt_outputs) == 1 and not isinstance(opt_outputs[0], Result): opt_outputs = opt_outputs[0] + epoch_output[opt_idx].append(opt_outputs) def get_optimizers_iterable(self): @@ -548,17 +563,14 @@ class TrainLoop: if batch_output.signal == -1: break - # only track outputs when user implements training_epoch_end - # otherwise we will build up unnecessary memory - epoch_end_outputs = self.process_train_step_outputs( + batch_end_outputs = self.process_train_step_outputs( batch_output.training_step_output_for_epoch_end, self.early_stopping_accumulator, self.checkpoint_accumulator, ) - # hook # TODO: add outputs to batches - self.on_train_batch_end(epoch_output, epoch_end_outputs, batch, batch_idx, dataloader_idx) + self.on_train_batch_end(epoch_output, batch_end_outputs, batch, batch_idx, dataloader_idx) # ----------------------------------------- # SAVE METRICS TO LOGGERS @@ -896,7 +908,7 @@ class TrainLoop: # the training step outputs a list per optimizer. The list contains the outputs at each time step # when no TBPTT is used, then the list has 1 item per batch # when TBPTT IS used, then the list has n items (1 per time step) - epoch_end_outputs = [] + batch_end_outputs = [] for optimizer_idx_outputs in all_train_step_outputs: # extract one representative sample from each time step (1 if no tbptt) and 0th optimizer if len(optimizer_idx_outputs) == 0: @@ -911,14 +923,9 @@ class TrainLoop: if isinstance(sample_output, dict) and "checkpoint_on" in sample_output: checkpoint_accumulator.accumulate(sample_output["checkpoint_on"]) - # decide if we need to reduce at the end of the epoch automatically - auto_reduce_tng_result = isinstance(sample_output, Result) and sample_output.should_reduce_on_epoch_end + batch_end_outputs.append(optimizer_idx_outputs) - # only track when a) it needs to be autoreduced OR b) the user wants to manually reduce on epoch end - if is_overridden("training_epoch_end", model=self.trainer.get_model()) or auto_reduce_tng_result: - epoch_end_outputs.append(optimizer_idx_outputs) - - return epoch_end_outputs + return batch_end_outputs def prepare_optimizers(self): # in manual optimization we loop over all optimizers at once diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 9fae1a7c20..9f9b03db4c 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -18,7 +18,8 @@ from unittest.mock import MagicMock import pytest import torch -from pytorch_lightning import Trainer + +from pytorch_lightning import Trainer, Callback from pytorch_lightning.accelerators.legacy.gpu_accelerator import GPUAccelerator from pytorch_lightning.trainer.states import TrainerState from tests.base import BoringModel, EvalModelTemplate, RandomDataset @@ -91,6 +92,58 @@ def test_training_epoch_end_metrics_collection(tmpdir): assert metrics[f'epoch_metric_{i}'] == i +def test_training_epoch_end_metrics_collection_on_override(tmpdir): + """ Test that batch end metrics are collected when training_epoch_end is overridden at the end of an epoch. """ + num_epochs = 1 + + class LoggingCallback(Callback): + + def on_train_epoch_end(self, trainer, pl_module): + self.len_outputs = 0 + + def on_train_epoch_end(self, trainer, pl_module, outputs): + self.len_outputs = len(outputs[0]) + + class OverriddenModel(EvalModelTemplate): + + def on_train_epoch_start(self): + self.num_train_batches = 0 + + def training_epoch_end(self, outputs): # Overridden + pass + return + + def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): + self.num_train_batches += 1 + + class NotOverriddenModel(EvalModelTemplate): + + def on_train_epoch_start(self): + self.num_train_batches = 0 + + def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): + self.num_train_batches += 1 + + overridden_model = OverriddenModel() + not_overridden_model = NotOverriddenModel() + + callback = LoggingCallback() + trainer = Trainer( + max_epochs=num_epochs, + default_root_dir=tmpdir, + overfit_batches=2, + callbacks=[callback], + ) + + result = trainer.fit(overridden_model) + assert callback.len_outputs == overridden_model.num_train_batches + # outputs from on_train_batch_end should be accessible in on_train_epoch_end hook if training_epoch_end is overridden + + result = trainer.fit(not_overridden_model) + assert callback.len_outputs == 0 + # outputs from on_train_batch_end should be empty + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") def test_transfer_batch_hook():