Fix Inconsistencies after introducing *_step_end and *_epoch_end (#1072)
* fix copy-paste errors after renaming *_end methods
* line too long
* update
* Update lightning.py
* Update lightning.py
* Update lightning.py

Co-authored-by: William Falcon <waf2107@columbia.edu>
parent 1a2726fa6b
commit 09482fb64f
@@ -306,7 +306,17 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
             (only if multiple val datasets used)
 
         Return:
-            Dict or OrderedDict - passed to the validation_epoch_end
+            Dict or OrderedDict - passed to validation_epoch_end.
+            If you defined validation_step_end it will go to that first.
+
+            .. code-block:: python
+
+                # pseudocode of order
+                out = validation_step()
+                if defined('validation_step_end'):
+                    out = validation_step_end(out)
+                out = validation_epoch_end(out)
+
         .. code-block:: python
 
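The restored pseudocode pins down the call order of the three hooks. As a minimal runnable sketch of that pipeline, assuming the 0.7-era dict-returning API (the module body and metric names are illustrative, not taken from this diff):

.. code-block:: python

    import torch
    import torch.nn.functional as F
    import pytorch_lightning as pl

    class LitModel(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(28 * 28, 10)

        def forward(self, x):
            return self.layer(x)

        def validation_step(self, batch, batch_idx):
            # called once per validation batch
            x, y = batch
            return {'val_loss': F.cross_entropy(self.forward(x), y)}

        def validation_step_end(self, output):
            # called only because it is defined; sees validation_step's output
            return output

        def validation_epoch_end(self, outputs):
            # called once with the list of every per-batch output
            avg = torch.stack([o['val_loss'] for o in outputs]).mean()
            return {'val_loss': avg}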
@@ -316,7 +326,7 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
             # if you have multiple val dataloaders:
             def validation_step(self, batch, batch_idx, dataloader_idx)
 
-        EExamples:
+        Examples:
             .. code-block:: python
 
                 # CASE 1: A single validation dataset
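The multi-dataloader signature shown above takes a third dataloader_idx argument. A hedged sketch of such a step (the per-loader metric keying is illustrative, not from this commit):

.. code-block:: python

    def validation_step(self, batch, batch_idx, dataloader_idx):
        # dataloader_idx identifies which val dataloader produced this batch
        x, y = batch
        loss = F.cross_entropy(self.forward(x), y)
        # illustrative: tag the metric so validation_epoch_end can tell loaders apart
        return {'val_loss_{}'.format(dataloader_idx): loss}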
@@ -359,12 +369,12 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
 
         .. note:: When the validation_step is called, the model has been put in eval mode
             and PyTorch gradients have been disabled. At the end of validation,
-            model goes back to training mode and gradients are enabled.
+            the model goes back to training mode and gradients are enabled.
         """
 
     def validation_step_end(self, *args, **kwargs):
         """
-        Use this when training with dp or ddp2 because training_step will operate
+        Use this when validating with dp or ddp2 because validation_step will operate
         on only part of the batch. However, this is still optional
         and only needed for things like softmax or NCE loss.
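The corrected note describes behaviour the trainer applies around the whole validation loop. Roughly, as a sketch of the idea rather than Lightning's actual internals:

.. code-block:: python

    # pseudocode of what the trainer does around validation
    model.eval()                      # switch off dropout, batch-norm updates, etc.
    with torch.no_grad():             # disable gradient tracking
        for batch_idx, batch in enumerate(val_dataloader):
            out = model.validation_step(batch, batch_idx)
    model.train()                     # back to training mode; gradients re-enabled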
@@ -375,17 +385,14 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
 
             # pseudocode
             sub_batches = split_batches_for_dp(batch)
-            batch_parts_outputs = [training_step(sub_batch) for sub_batch in sub_batches]
+            batch_parts_outputs = [validation_step(sub_batch) for sub_batch in sub_batches]
             validation_step_end(batch_parts_outputs)
 
         Args:
-            batch_parts_outputs: What you return in `training_step` for each batch part.
+            batch_parts_outputs: What you return in `validation_step` for each batch part.
 
         Return:
-            dictionary with loss key and optional log, progress keys:
-                - loss -> tensor scalar [REQUIRED]
-                - progress_bar -> Dict for progress bar display. Must have only tensors
-                - log -> Dict of metrics to add to logger. Must have only tensors (no images, etc)
+            Dict or OrderedDict - passed to the validation_epoch_end
 
         In this case you should define validation_step_end to perform those calculations.
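The corrected pseudocode is the key contract here: under dp or ddp2 each GPU runs validation_step on a slice of the batch, and validation_step_end is the first place that sees all the slices together. A hedged sketch of a step_end that uses that, assuming each part returned an 'out' tensor as in the example further down:

.. code-block:: python

    def validation_step_end(self, outputs):
        # under dp/ddp2, outputs['out'] now holds the parts from every GPU
        out = outputs['out']
        # normalize across the full batch, which no single GPU could do alone
        probs = torch.softmax(out, dim=1)
        return {'probs': probs}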
@@ -394,7 +401,7 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
 
             # WITHOUT validation_step_end
             # if used in DP or DDP2, this batch is 1/num_gpus large
-            def training_step(self, batch, batch_idx):
+            def validation_step(self, batch, batch_idx):
                 # batch is 1/num_gpus big
                 x, y = batch
 
@@ -405,14 +412,14 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
 
             # --------------
             # with validation_step_end to do softmax over the full batch
-            def training_step(self, batch, batch_idx):
+            def validation_step(self, batch, batch_idx):
                 # batch is 1/num_gpus big
                 x, y = batch
 
                 out = self.forward(x)
                 return {'out': out}
 
-            def validation_step_end(self, outputs):
+            def validation_epoch_end(self, outputs):
                 # this out is now the full size of the batch
                 out = outputs['out']
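The with/without pair above is the heart of the rename: without a step_end, anything computed inside validation_step under dp or ddp2 only ever sees 1/num_gpus of the samples, so batch-global quantities (a softmax normalized across the batch, an NCE denominator) come out silently wrong. Moving that computation into validation_step_end runs it once over the reassembled batch.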
@@ -440,7 +447,7 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
             val_outs = []
             for val_batch in val_data:
                 out = validation_step(train_batch)
-                train_outs.append(out
+                train_outs.append(out)
             validation_epoch_end(val_outs)
 
         Args:
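The loop above collects one output per batch and hands the whole list to validation_epoch_end. A hedged sketch of the matching reduction, using the 0.7-era convention that 'log' and 'progress_bar' sub-dicts route metrics to the logger and the progress bar (the metric name is illustrative):

.. code-block:: python

    def validation_epoch_end(self, outputs):
        # outputs: one dict per validation batch
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        return {'val_loss': avg_loss,
                'log': {'val_loss': avg_loss},
                'progress_bar': {'val_loss': avg_loss}}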
@@ -516,7 +523,7 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
             test_outs = []
             for test_batch in test_data:
                 out = test_step(train_batch)
-                test_outs.append(out
+                test_outs.append(out)
             test_epoch_end(test_outs)
 
         Args:
@@ -527,7 +534,7 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
             (only if multiple test datasets used)
 
         Return:
-            Dict or OrderedDict - passed to the test_epoch_end
+            Dict or OrderedDict - passed to the test_step_end
 
         .. code-block:: python
 
@@ -579,14 +586,14 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
 
         .. note:: If you don't need to validate you don't need to implement this method.
 
-        .. note:: When the validation_step is called, the model has been put in eval mode and
-            PyTorch gradients have been disabled. At the end of validation, model goes back
+        .. note:: When the test_step is called, the model has been put in eval mode and
+            PyTorch gradients have been disabled. At the end of the test epoch, the model goes back
             to training mode and gradients are enabled.
         """
 
     def test_step_end(self, *args, **kwargs):
         """
-        Use this when training with dp or ddp2 because training_step will operate
+        Use this when testing with dp or ddp2 because test_step will operate
         on only part of the batch. However, this is still optional
         and only needed for things like softmax or NCE loss.
@@ -597,17 +604,14 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
 
             # pseudocode
            sub_batches = split_batches_for_dp(batch)
-            batch_parts_outputs = [training_step(sub_batch) for sub_batch in sub_batches]
+            batch_parts_outputs = [test_step(sub_batch) for sub_batch in sub_batches]
             test_step_end(batch_parts_outputs)
 
         Args:
             batch_parts_outputs: What you return in `training_step` for each batch part.
 
         Return:
-            dictionary with loss key and optional log, progress keys:
-                - loss -> tensor scalar [REQUIRED]
-                - progress_bar -> Dict for progress bar display. Must have only tensors
-                - log -> Dict of metrics to add to logger. Must have only tensors (no images, etc)
+            Dict or OrderedDict - passed to the test_epoch_end
 
         In this case you should define test_step_end to perform those calculations.
@@ -616,7 +620,7 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
 
             # WITHOUT test_step_end
             # if used in DP or DDP2, this batch is 1/num_gpus large
-            def training_step(self, batch, batch_idx):
+            def test_step(self, batch, batch_idx):
                 # batch is 1/num_gpus big
                 x, y = batch
 
@@ -627,7 +631,7 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
 
             # --------------
             # with test_step_end to do softmax over the full batch
-            def training_step(self, batch, batch_idx):
+            def test_step(self, batch, batch_idx):
                 # batch is 1/num_gpus big
                 x, y = batch
 
@@ -653,7 +657,7 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
 
     def test_epoch_end(self, outputs):
         """
-        Called at end of test epoch with the output of all test_steps
+        Called at end of test epoch with the output of all test_steps.
 
         .. code-block:: python
 
@@ -666,7 +670,6 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
             test_epoch_end(test_outs)
 
         Args:
-
             outputs (list): List of outputs you defined in test_step, or if there are multiple
                 dataloaders, a list containing a list of outputs for each dataloader
 
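With multiple test dataloaders, the Args entry above says outputs becomes a list containing one list per dataloader. A hedged sketch of an epoch end that copes with both shapes (the flattening logic is illustrative, not from this commit):

.. code-block:: python

    def test_epoch_end(self, outputs):
        # with several test dataloaders, outputs is a list of lists
        if outputs and isinstance(outputs[0], list):
            outputs = [o for per_loader in outputs for o in per_loader]
        avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
        return {'test_loss': avg_loss}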