Fix Inconsistencies after introducing *_step_end and *_epoch_end (#1072)

* fix copy-paste errors after renaming *_end methods

* line too long

* update

* Update lightning.py

* Update lightning.py

* Update lightning.py

Co-authored-by: William Falcon <waf2107@columbia.edu>
Adrian Wälchli 2020-03-06 16:33:17 +01:00 committed by GitHub
parent 1a2726fa6b
commit 09482fb64f
1 changed file with 31 additions and 28 deletions


@@ -306,7 +306,17 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
(only if multiple val datasets used)
Return:
Dict or OrderedDict - passed to the validation_epoch_end
Dict or OrderedDict - passed to validation_epoch_end.
If you defined validation_step_end it will go to that first.
.. code-block:: python
# pseudocode of order
out = validation_step()
if defined('validation_step_end'):
out = validation_step_end(out)
out = validation_epoch_end(out)
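
In practice the three hooks in that order might fit together as in the minimal sketch below; the model, layer sizes and metric names are illustrative assumptions, not part of the file being changed.

.. code-block:: python

    # hypothetical module showing only the call order described above
    import torch
    import torch.nn.functional as F
    import pytorch_lightning as pl

    class LitClassifier(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(28 * 28, 10)

        def forward(self, x):
            return self.layer(x.view(x.size(0), -1))

        def validation_step(self, batch, batch_idx):
            x, y = batch
            return {'val_loss': F.cross_entropy(self.forward(x), y)}

        def validation_step_end(self, output):
            # receives the output of validation_step (gathered across GPUs under dp/ddp2)
            return output

        def validation_epoch_end(self, outputs):
            avg_loss = torch.stack([o['val_loss'] for o in outputs]).mean()
            return {'val_loss': avg_loss, 'log': {'val_loss': avg_loss}}
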
.. code-block:: python
@@ -316,7 +326,7 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
# if you have multiple val dataloaders:
def validation_step(self, batch, batch_idx, dataloader_idx)
EExamples:
Examples:
.. code-block:: python
# CASE 1: A single validation dataset
@@ -359,12 +369,12 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
.. note:: When the validation_step is called, the model has been put in eval mode
and PyTorch gradients have been disabled. At the end of validation,
model goes back to training mode and gradients are enabled.
the model goes back to training mode and gradients are enabled.
"""
def validation_step_end(self, *args, **kwargs):
"""
Use this when training with dp or ddp2 because training_step will operate
Use this when validating with dp or ddp2 because validation_step will operate
on only part of the batch. However, this is still optional
and only needed for things like softmax or NCE loss.
@@ -375,17 +385,14 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
# pseudocode
sub_batches = split_batches_for_dp(batch)
batch_parts_outputs = [training_step(sub_batch) for sub_batch in sub_batches]
batch_parts_outputs = [validation_step(sub_batch) for sub_batch in sub_batches]
validation_step_end(batch_parts_outputs)
Args:
batch_parts_outputs: What you return in `training_step` for each batch part.
batch_parts_outputs: What you return in `validation_step` for each batch part.
Return:
dictionary with loss key and optional log, progress keys:
- loss -> tensor scalar [REQUIRED]
- progress_bar -> Dict for progress bar display. Must have only tensors
- log -> Dict of metrics to add to logger. Must have only tensors (no images, etc)
Dict or OrderedDict - passed to the validation_epoch_end
In this case you should define validation_step_end to perform those calculations.
@@ -394,7 +401,7 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
# WITHOUT validation_step_end
# if used in DP or DDP2, this batch is 1/num_gpus large
def training_step(self, batch, batch_idx):
def validation_step(self, batch, batch_idx):
# batch is 1/num_gpus big
x, y = batch
@@ -405,14 +412,14 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
# --------------
# with validation_step_end to do softmax over the full batch
def training_step(self, batch, batch_idx):
def validation_step(self, batch, batch_idx):
# batch is 1/num_gpus big
x, y = batch
out = self.forward(x)
return {'out': out}
def validation_step_end(self, outputs):
def validation_epoch_end(self, outputs):
# this out is now the full size of the batch
out = outputs['out']
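
A possible completion of that full-batch pattern, sketched here with a softmax-style loss computed only once the per-GPU parts have been merged; the tensor and key names are illustrative, not the exact text of the file.

.. code-block:: python

    # illustrative pair: the step returns the raw outputs of its GPU slice,
    # the step-end hook sees the gathered, full-size batch
    def validation_step(self, batch, batch_idx):
        # batch is 1/num_gpus big under dp or ddp2
        x, y = batch
        return {'out': self.forward(x), 'y': y}

    def validation_step_end(self, outputs):
        out, y = outputs['out'], outputs['y']
        loss = torch.nn.functional.cross_entropy(out, y)  # softmax over the full batch
        return {'val_loss': loss}
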
@@ -440,7 +447,7 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
val_outs = []
for val_batch in val_data:
out = validation_step(train_batch)
train_outs.append(out
train_outs.append(out)
validation_epoch_end(val_outs)
Args:
@@ -516,7 +523,7 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
test_outs = []
for test_batch in test_data:
out = test_step(train_batch)
test_outs.append(out
test_outs.append(out)
test_epoch_end(test_outs)
Args:
@@ -527,7 +534,7 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
(only if multiple test datasets used)
Return:
Dict or OrderedDict - passed to the test_epoch_end
Dict or OrderedDict - passed to the test_step_end
.. code-block:: python
@@ -579,14 +586,14 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
.. note:: If you don't need to validate you don't need to implement this method.
.. note:: When the validation_step is called, the model has been put in eval mode and
PyTorch gradients have been disabled. At the end of validation, model goes back
.. note:: When the test_step is called, the model has been put in eval mode and
PyTorch gradients have been disabled. At the end of the test epoch, the model goes back
to training mode and gradients are enabled.
"""
def test_step_end(self, *args, **kwargs):
"""
Use this when training with dp or ddp2 because training_step will operate
Use this when testing with dp or ddp2 because test_step will operate
on only part of the batch. However, this is still optional
and only needed for things like softmax or NCE loss.
@@ -597,17 +604,14 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
# pseudocode
sub_batches = split_batches_for_dp(batch)
batch_parts_outputs = [training_step(sub_batch) for sub_batch in sub_batches]
batch_parts_outputs = [test_step(sub_batch) for sub_batch in sub_batches]
test_step_end(batch_parts_outputs)
Args:
batch_parts_outputs: What you return in `training_step` for each batch part.
Return:
dictionary with loss key and optional log, progress keys:
- loss -> tensor scalar [REQUIRED]
- progress_bar -> Dict for progress bar display. Must have only tensors
- log -> Dict of metrics to add to logger. Must have only tensors (no images, etc)
Dict or OrderedDict - passed to the test_epoch_end
In this case you should define test_step_end to perform those calculations.
@@ -616,7 +620,7 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
# WITHOUT test_step_end
# if used in DP or DDP2, this batch is 1/num_gpus large
def training_step(self, batch, batch_idx):
def test_step(self, batch, batch_idx):
# batch is 1/num_gpus big
x, y = batch
@@ -627,7 +631,7 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
# --------------
# with test_step_end to do softmax over the full batch
def training_step(self, batch, batch_idx):
def test_step(self, batch, batch_idx):
# batch is 1/num_gpus big
x, y = batch
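
As in the validation case, one way the full-batch computation could be finished is sketched below; the key names are assumptions.

.. code-block:: python

    # illustrative: test_step returns its per-GPU slice,
    # test_step_end receives the gathered full batch
    def test_step(self, batch, batch_idx):
        x, y = batch
        return {'out': self.forward(x), 'y': y}

    def test_step_end(self, outputs):
        out, y = outputs['out'], outputs['y']
        loss = torch.nn.functional.cross_entropy(out, y)
        return {'test_loss': loss}
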
@@ -653,7 +657,7 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
def test_epoch_end(self, outputs):
"""
Called at end of test epoch with the output of all test_steps
Called at end of test epoch with the output of all test_steps.
.. code-block:: python
@@ -666,7 +670,6 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
test_epoch_end(test_outs)
Args:
outputs (list): List of outputs you defined in test_step, or if there are multiple
dataloaders, a list containing a list of outputs for each dataloader
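
Because outputs is a flat list with a single test dataloader and a list of lists with several, an aggregation sketch might flatten first; the key names and the flattening step are assumptions, not part of the file.

.. code-block:: python

    # sketch: average test_loss over all batches, flattening per-dataloader lists first
    def test_epoch_end(self, outputs):
        if outputs and isinstance(outputs[0], list):
            outputs = [o for per_loader in outputs for o in per_loader]
        avg_loss = torch.stack([o['test_loss'] for o in outputs]).mean()
        return {'test_loss': avg_loss, 'log': {'test_loss': avg_loss}}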