passing batch outputs to on_train_batch_end (#4369)

* passing batch outputs to on_train_batch_end

* styling

* updating epoch end logic

* also condition on on_train_epoch_end hooks

* more readable

* pep8

* pep8

* readability suggestion accepted

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* adding test_training_epoch_end_metrics_collection_on_override test

* fix formatting

* fix formatting

Co-authored-by: Swetha Mandava <smandava@nvidia.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Sean Naren <sean.narenthiran@gmail.com>
Co-authored-by: Roger Shieh <sh.rog@protonmail.ch>

(cherry picked from commit 5fcca4e43b)
This commit is contained in:
Swetha Mandava 2021-01-26 14:01:46 +01:00 committed by Jirka Borovec
parent da5ba50727
commit c62f68c7cd
2 changed files with 79 additions and 19 deletions

View File

@ -241,13 +241,13 @@ class TrainLoop:
self.trainer.call_hook("on_epoch_start")
self.trainer.call_hook("on_train_epoch_start")
def on_train_batch_end(self, epoch_output, epoch_end_outputs, batch, batch_idx, dataloader_idx):
def on_train_batch_end(self, epoch_output, batch_end_outputs, batch, batch_idx, dataloader_idx):
# hook
self.trainer.call_hook('on_train_batch_end', epoch_end_outputs, batch, batch_idx, dataloader_idx)
self.trainer.call_hook('on_train_batch_end', batch_end_outputs, batch, batch_idx, dataloader_idx)
self.trainer.call_hook('on_batch_end')
# figure out what to track for epoch end
self.track_epoch_end_reduce_metrics(epoch_output, epoch_end_outputs)
self.track_epoch_end_reduce_metrics(epoch_output, batch_end_outputs)
# reset batch logger internals
self.trainer.logger_connector.on_train_batch_end()
@ -259,12 +259,27 @@ class TrainLoop:
if self.trainer.val_dataloaders is None and not self.trainer.reload_dataloaders_every_epoch:
self.trainer.reset_val_dataloader(model)
def track_epoch_end_reduce_metrics(self, epoch_output, epoch_end_outputs):
def track_epoch_end_reduce_metrics(self, epoch_output, batch_end_outputs):
# track the outputs to reduce at the end of the epoch
for opt_idx, opt_outputs in enumerate(epoch_end_outputs):
for opt_idx, opt_outputs in enumerate(batch_end_outputs):
sample_output = opt_outputs[-1]
# decide if we need to reduce at the end of the epoch automatically
auto_reduce_tng_result = isinstance(sample_output, Result) and sample_output.should_reduce_on_epoch_end
hook_overridden = (
is_overridden("training_epoch_end", model=self.trainer.get_model()) or
is_overridden("on_train_epoch_end", model=self.trainer.get_model())
)
# only track when a) it needs to be autoreduced OR b) the user wants to manually reduce on epoch end
if not(hook_overridden or auto_reduce_tng_result):
continue
# with 1 step (no tbptt) don't use a sequence at epoch end
if isinstance(opt_outputs, list) and len(opt_outputs) == 1 and not isinstance(opt_outputs[0], Result):
opt_outputs = opt_outputs[0]
epoch_output[opt_idx].append(opt_outputs)
def get_optimizers_iterable(self):
@ -548,17 +563,14 @@ class TrainLoop:
if batch_output.signal == -1:
break
# only track outputs when user implements training_epoch_end
# otherwise we will build up unnecessary memory
epoch_end_outputs = self.process_train_step_outputs(
batch_end_outputs = self.process_train_step_outputs(
batch_output.training_step_output_for_epoch_end,
self.early_stopping_accumulator,
self.checkpoint_accumulator,
)
# hook
# TODO: add outputs to batches
self.on_train_batch_end(epoch_output, epoch_end_outputs, batch, batch_idx, dataloader_idx)
self.on_train_batch_end(epoch_output, batch_end_outputs, batch, batch_idx, dataloader_idx)
# -----------------------------------------
# SAVE METRICS TO LOGGERS
@ -896,7 +908,7 @@ class TrainLoop:
# the training step outputs a list per optimizer. The list contains the outputs at each time step
# when no TBPTT is used, then the list has 1 item per batch
# when TBPTT IS used, then the list has n items (1 per time step)
epoch_end_outputs = []
batch_end_outputs = []
for optimizer_idx_outputs in all_train_step_outputs:
# extract one representative sample from each time step (1 if no tbptt) and 0th optimizer
if len(optimizer_idx_outputs) == 0:
@ -911,14 +923,9 @@ class TrainLoop:
if isinstance(sample_output, dict) and "checkpoint_on" in sample_output:
checkpoint_accumulator.accumulate(sample_output["checkpoint_on"])
# decide if we need to reduce at the end of the epoch automatically
auto_reduce_tng_result = isinstance(sample_output, Result) and sample_output.should_reduce_on_epoch_end
batch_end_outputs.append(optimizer_idx_outputs)
# only track when a) it needs to be autoreduced OR b) the user wants to manually reduce on epoch end
if is_overridden("training_epoch_end", model=self.trainer.get_model()) or auto_reduce_tng_result:
epoch_end_outputs.append(optimizer_idx_outputs)
return epoch_end_outputs
return batch_end_outputs
def prepare_optimizers(self):
# in manual optimization we loop over all optimizers at once

View File

@ -18,7 +18,8 @@ from unittest.mock import MagicMock
import pytest
import torch
from pytorch_lightning import Trainer
from pytorch_lightning import Trainer, Callback
from pytorch_lightning.accelerators.legacy.gpu_accelerator import GPUAccelerator
from pytorch_lightning.trainer.states import TrainerState
from tests.base import BoringModel, EvalModelTemplate, RandomDataset
@ -91,6 +92,58 @@ def test_training_epoch_end_metrics_collection(tmpdir):
assert metrics[f'epoch_metric_{i}'] == i
def test_training_epoch_end_metrics_collection_on_override(tmpdir):
""" Test that batch end metrics are collected when training_epoch_end is overridden at the end of an epoch. """
num_epochs = 1
class LoggingCallback(Callback):
def on_train_epoch_end(self, trainer, pl_module):
self.len_outputs = 0
def on_train_epoch_end(self, trainer, pl_module, outputs):
self.len_outputs = len(outputs[0])
class OverriddenModel(EvalModelTemplate):
def on_train_epoch_start(self):
self.num_train_batches = 0
def training_epoch_end(self, outputs): # Overridden
pass
return
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
self.num_train_batches += 1
class NotOverriddenModel(EvalModelTemplate):
def on_train_epoch_start(self):
self.num_train_batches = 0
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
self.num_train_batches += 1
overridden_model = OverriddenModel()
not_overridden_model = NotOverriddenModel()
callback = LoggingCallback()
trainer = Trainer(
max_epochs=num_epochs,
default_root_dir=tmpdir,
overfit_batches=2,
callbacks=[callback],
)
result = trainer.fit(overridden_model)
assert callback.len_outputs == overridden_model.num_train_batches
# outputs from on_train_batch_end should be accessible in on_train_epoch_end hook if training_epoch_end is overridden
result = trainer.fit(not_overridden_model)
assert callback.len_outputs == 0
# outputs from on_train_batch_end should be empty
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
def test_transfer_batch_hook():