passing batch outputs to on_train_batch_end (#4369)
* passing batch outputs to on_train_batch_end
* styling
* updating epoch end logic
* also condition on on_train_epoch_end hooks
* more readable
* pep8
* pep8
* readability suggestion accepted
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
* adding test_training_epoch_end_metrics_collection_on_override test
* fix formatting
* fix formatting
Co-authored-by: Swetha Mandava <smandava@nvidia.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Sean Naren <sean.narenthiran@gmail.com>
Co-authored-by: Roger Shieh <sh.rog@protonmail.ch>
(cherry picked from commit 5fcca4e43b)
parent da5ba50727 · commit c62f68c7cd
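For orientation before the diff itself, here is a minimal sketch (not part of the commit) of the user-facing behavior the change enables: the per-batch training step outputs are now forwarded to the on_train_batch_end hook. The hook signature is the one exercised by the test added below; the class name, layer sizes and bookkeeping list are hypothetical, and this assumes a pytorch_lightning install from this era.

import torch
from torch import nn
from pytorch_lightning import LightningModule


class BatchOutputModel(LightningModule):
    # Hypothetical module for illustration only; the on_train_batch_end signature
    # matches the one used by the test added in this commit.

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 2)
        self.seen_batch_outputs = []

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x = batch[0] if isinstance(batch, (list, tuple)) else batch
        return {"loss": self(x).sum()}

    def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
        # with this change, `outputs` carries the processed training_step results for
        # this batch (nested per optimizer / truncated-BPTT step in this version)
        self.seen_batch_outputs.append(outputs)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)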
@@ -241,13 +241,13 @@ class TrainLoop:
         self.trainer.call_hook("on_epoch_start")
         self.trainer.call_hook("on_train_epoch_start")

-    def on_train_batch_end(self, epoch_output, epoch_end_outputs, batch, batch_idx, dataloader_idx):
+    def on_train_batch_end(self, epoch_output, batch_end_outputs, batch, batch_idx, dataloader_idx):
         # hook
-        self.trainer.call_hook('on_train_batch_end', epoch_end_outputs, batch, batch_idx, dataloader_idx)
+        self.trainer.call_hook('on_train_batch_end', batch_end_outputs, batch, batch_idx, dataloader_idx)
         self.trainer.call_hook('on_batch_end')

         # figure out what to track for epoch end
-        self.track_epoch_end_reduce_metrics(epoch_output, epoch_end_outputs)
+        self.track_epoch_end_reduce_metrics(epoch_output, batch_end_outputs)

         # reset batch logger internals
         self.trainer.logger_connector.on_train_batch_end()
@@ -259,12 +259,27 @@ class TrainLoop:
         if self.trainer.val_dataloaders is None and not self.trainer.reload_dataloaders_every_epoch:
             self.trainer.reset_val_dataloader(model)

-    def track_epoch_end_reduce_metrics(self, epoch_output, epoch_end_outputs):
+    def track_epoch_end_reduce_metrics(self, epoch_output, batch_end_outputs):

         # track the outputs to reduce at the end of the epoch
-        for opt_idx, opt_outputs in enumerate(epoch_end_outputs):
+        for opt_idx, opt_outputs in enumerate(batch_end_outputs):
+            sample_output = opt_outputs[-1]
+
+            # decide if we need to reduce at the end of the epoch automatically
+            auto_reduce_tng_result = isinstance(sample_output, Result) and sample_output.should_reduce_on_epoch_end
+            hook_overridden = (
+                is_overridden("training_epoch_end", model=self.trainer.get_model()) or
+                is_overridden("on_train_epoch_end", model=self.trainer.get_model())
+            )
+
+            # only track when a) it needs to be autoreduced OR b) the user wants to manually reduce on epoch end
+            if not(hook_overridden or auto_reduce_tng_result):
+                continue
+
             # with 1 step (no tbptt) don't use a sequence at epoch end
             if isinstance(opt_outputs, list) and len(opt_outputs) == 1 and not isinstance(opt_outputs[0], Result):
                 opt_outputs = opt_outputs[0]

             epoch_output[opt_idx].append(opt_outputs)

     def get_optimizers_iterable(self):
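With the hook_overridden check above, per-batch outputs are kept for the epoch whenever the model overrides training_epoch_end or on_train_epoch_end, or a Result requests automatic epoch-end reduction. A minimal sketch of a consumer on the callback side; the hook signature is copied from the test added in this commit, while the class name and counting logic are illustrative only.

from pytorch_lightning import Callback


class OutputCountingCallback(Callback):
    # Hypothetical callback; signature mirrors the LoggingCallback in the new test.

    def on_train_epoch_end(self, trainer, pl_module, outputs):
        # `outputs` holds one list per optimizer; it stays empty unless the model
        # overrides one of the epoch-end hooks (or a Result asks for auto-reduction),
        # which is exactly what the gate above decides.
        self.num_tracked_batches = len(outputs[0]) if outputs else 0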
@@ -548,17 +563,14 @@ class TrainLoop:
             if batch_output.signal == -1:
                 break

-            # only track outputs when user implements training_epoch_end
-            # otherwise we will build up unnecessary memory
-            epoch_end_outputs = self.process_train_step_outputs(
+            batch_end_outputs = self.process_train_step_outputs(
                 batch_output.training_step_output_for_epoch_end,
                 self.early_stopping_accumulator,
                 self.checkpoint_accumulator,
             )

             # hook
-            # TODO: add outputs to batches
-            self.on_train_batch_end(epoch_output, epoch_end_outputs, batch, batch_idx, dataloader_idx)
+            self.on_train_batch_end(epoch_output, batch_end_outputs, batch, batch_idx, dataloader_idx)

             # -----------------------------------------
             # SAVE METRICS TO LOGGERS
@@ -896,7 +908,7 @@ class TrainLoop:
         # the training step outputs a list per optimizer. The list contains the outputs at each time step
         # when no TBPTT is used, then the list has 1 item per batch
         # when TBPTT IS used, then the list has n items (1 per time step)
-        epoch_end_outputs = []
+        batch_end_outputs = []
         for optimizer_idx_outputs in all_train_step_outputs:
             # extract one representative sample from each time step (1 if no tbptt) and 0th optimizer
             if len(optimizer_idx_outputs) == 0:
@@ -911,14 +923,9 @@ class TrainLoop:
             if isinstance(sample_output, dict) and "checkpoint_on" in sample_output:
                 checkpoint_accumulator.accumulate(sample_output["checkpoint_on"])

-            # decide if we need to reduce at the end of the epoch automatically
-            auto_reduce_tng_result = isinstance(sample_output, Result) and sample_output.should_reduce_on_epoch_end
-
-            # only track when a) it needs to be autoreduced OR b) the user wants to manually reduce on epoch end
-            if is_overridden("training_epoch_end", model=self.trainer.get_model()) or auto_reduce_tng_result:
-                epoch_end_outputs.append(optimizer_idx_outputs)
+            batch_end_outputs.append(optimizer_idx_outputs)

-        return epoch_end_outputs
+        return batch_end_outputs

     def prepare_optimizers(self):
         # in manual optimization we loop over all optimizers at once
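To make the nesting described in the comments above concrete, here is an illustrative value of what process_train_step_outputs now returns. The numbers and dict keys are invented; the real entries are the objects returned by training_step (or Result objects).

# Illustrative only: the shape of batch_end_outputs
batch_end_outputs = [
    # optimizer 0: one entry per truncated-BPTT time step (a single entry when TBPTT is off)
    [{"loss": 0.91}],
    # optimizer 1: present only when multiple optimizers are configured
    [{"loss": 0.47}],
]
# on_train_batch_end receives this structure, and track_epoch_end_reduce_metrics
# appends each optimizer's entry into epoch_output[opt_idx] when the gating above says to track it.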
@@ -18,7 +18,8 @@ from unittest.mock import MagicMock
 import pytest
 import torch

-from pytorch_lightning import Trainer
+
+from pytorch_lightning import Trainer, Callback
 from pytorch_lightning.accelerators.legacy.gpu_accelerator import GPUAccelerator
 from pytorch_lightning.trainer.states import TrainerState
 from tests.base import BoringModel, EvalModelTemplate, RandomDataset
@@ -91,6 +92,58 @@ def test_training_epoch_end_metrics_collection(tmpdir):
         assert metrics[f'epoch_metric_{i}'] == i


+def test_training_epoch_end_metrics_collection_on_override(tmpdir):
+    """ Test that batch end metrics are collected when training_epoch_end is overridden at the end of an epoch. """
+    num_epochs = 1
+
+    class LoggingCallback(Callback):
+
+        def on_train_epoch_end(self, trainer, pl_module):
+            self.len_outputs = 0
+
+        def on_train_epoch_end(self, trainer, pl_module, outputs):
+            self.len_outputs = len(outputs[0])
+
+    class OverriddenModel(EvalModelTemplate):
+
+        def on_train_epoch_start(self):
+            self.num_train_batches = 0
+
+        def training_epoch_end(self, outputs):  # Overridden
+            pass
+            return
+
+        def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
+            self.num_train_batches += 1
+
+    class NotOverriddenModel(EvalModelTemplate):
+
+        def on_train_epoch_start(self):
+            self.num_train_batches = 0
+
+        def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
+            self.num_train_batches += 1
+
+    overridden_model = OverriddenModel()
+    not_overridden_model = NotOverriddenModel()
+
+    callback = LoggingCallback()
+    trainer = Trainer(
+        max_epochs=num_epochs,
+        default_root_dir=tmpdir,
+        overfit_batches=2,
+        callbacks=[callback],
+    )
+
+    result = trainer.fit(overridden_model)
+    assert callback.len_outputs == overridden_model.num_train_batches
+    # outputs from on_train_batch_end should be accessible in on_train_epoch_end hook if training_epoch_end is overridden
+
+    result = trainer.fit(not_overridden_model)
+    assert callback.len_outputs == 0
+    # outputs from on_train_batch_end should be empty
+
+
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
 def test_transfer_batch_hook():
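Finally, a hedged end-to-end wiring of the pieces sketched earlier. BatchOutputModel and OutputCountingCallback are the hypothetical classes from the sketches above, the random dataset is made up, and the Trainer arguments mirror the new test; under those assumptions the module should see one outputs object per training batch while the callback counts zero tracked entries, since this model overrides neither epoch-end hook.

import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import Trainer

# made-up data: 8 random samples, 4 per batch -> 2 training batches
train_loader = DataLoader(TensorDataset(torch.randn(8, 32)), batch_size=4)

model = BatchOutputModel()            # hypothetical module from the sketch above
callback = OutputCountingCallback()   # hypothetical callback from the sketch above

trainer = Trainer(
    max_epochs=1,
    overfit_batches=2,
    callbacks=[callback],
)
trainer.fit(model, train_loader)

# likely 2 and 0 here: every batch reaches on_train_batch_end, but nothing is tracked
# for the epoch because the model overrides neither training_epoch_end nor on_train_epoch_end
print(len(model.seen_batch_outputs), callback.num_tracked_batches)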