Improved docs for LightningModule (#1389)

* improve docs for LightningModule
* fix typos
* revert a doctest
* more fixes

parent 5f6be4dd53
commit 26cb5f6817
@ -74,17 +74,18 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
|
||||||
|
|
||||||
def print(self, *args, **kwargs) -> None:
|
def print(self, *args, **kwargs) -> None:
|
||||||
r"""
|
r"""
|
||||||
Prints only from process 0. Use this in any distributed mode to log only once
|
Prints only from process 0. Use this in any distributed mode to log only once.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
x (object): The thing to print
|
*args: The thing to print. Will be passed to Python's built-in print function.
|
||||||
|
**kwargs: Will be passed to Python's built-in print function.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
Examples:
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
# example if we were using this model as a feature extractor
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
self.print(x, 'in loader')
|
self.print(x, 'in forward')
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if self.trainer.proc_rank == 0:
|
if self.trainer.proc_rank == 0:
|
||||||
|
@ -93,15 +94,16 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def forward(self, *args, **kwargs):
|
def forward(self, *args, **kwargs):
|
||||||
r"""
|
r"""
|
||||||
Same as torch.nn.Module.forward(), however in Lightning you want this to define
|
Same as :meth:`torch.nn.Module.forward()`, however in Lightning you want this to define
|
||||||
the operations you want to use for prediction (ie: on a server or as a feature extractor).
|
the operations you want to use for prediction (i.e.: on a server or as a feature extractor).
|
||||||
|
|
||||||
Normally you'd call self() from your training_step() method.
|
Normally you'd call ``self()`` from your :meth:`training_step` method.
|
||||||
This makes it easy to write a complex system for training with the outputs
|
This makes it easy to write a complex system for training with the outputs
|
||||||
you'd want in a prediction setting.
|
you'd want in a prediction setting.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
x (tensor): Whatever you decide to define in the forward method
|
*args: Whatever you decide to pass into the forward method.
|
||||||
|
**kwargs: Keyword arguments are also possible.
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
Predicted output
|
Predicted output
|
||||||
|
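In rough outline, the forward/training split described above looks like this (a minimal sketch; the encoder, decoder and loss function are illustrative, not part of the diff):

.. code-block:: python

    # forward() defines the prediction path; training_step() calls self()
    def forward(self, x):
        embedding = self.encoder(x)        # assumed submodule
        return self.decoder(embedding)     # assumed submodule

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)                    # runs forward()
        loss = F.cross_entropy(y_hat, y)   # assumes: import torch.nn.functional as F
        return {'loss': loss}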
@ -142,21 +144,23 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
|
||||||
def training_step(self, *args, **kwargs) -> Union[
|
def training_step(self, *args, **kwargs) -> Union[
|
||||||
int, Dict[str, Union[Tensor, Dict[str, Tensor]]]
|
int, Dict[str, Union[Tensor, Dict[str, Tensor]]]
|
||||||
]:
|
]:
|
||||||
r"""return loss, dict with metrics for tqdm
|
r"""
|
||||||
|
Here you compute and return the training loss and some additional metrics, e.g. for
|
||||||
|
the progress bar or logger.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
batch (torch.nn.Tensor | (Tensor, Tensor) | [Tensor, Tensor]): The output of your
|
batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
|
||||||
dataloader. A tensor, tuple or list
|
The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
|
||||||
batch_idx (int): Integer displaying index of this batch
|
batch_idx (int): Integer displaying index of this batch
|
||||||
optimizer_idx (int): If using multiple optimizers, this argument will also be present.
|
optimizer_idx (int): When using multiple optimizers, this argument will also be present.
|
||||||
hiddens(:`Tensor <https://pytorch.org/docs/stable/tensors.html>`_):
|
hiddens(:class:`~torch.Tensor`): Passed in if
|
||||||
Passed in if truncated_bptt_steps > 0.
|
:paramref:`~pytorch_lightning.trainer.trainer.Trainer.truncated_bptt_steps` > 0.
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
dict with loss key and optional log, progress keys
|
Dict with loss key and optional log or progress bar keys.
|
||||||
if implementing training_step, return whatever you need in that step:
|
When implementing :meth:`training_step`, return whatever you need in that step:
|
||||||
|
|
||||||
- loss -> tensor scalar [REQUIRED]
|
- loss -> tensor scalar **REQUIRED**
|
||||||
- progress_bar -> Dict for progress bar display. Must have only tensors
|
- progress_bar -> Dict for progress bar display. Must have only tensors
|
||||||
- log -> Dict of metrics to add to logger. Must have only tensors (no images, etc)
|
- log -> Dict of metrics to add to logger. Must have only tensors (no images, etc)
|
||||||
|
|
||||||
|
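A return value using all three documented keys might look like this (a minimal sketch; the metric names are illustrative):

.. code-block:: python

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)   # assumes: import torch.nn.functional as F

        # 'loss' is required; 'log' and 'progress_bar' are optional tensor-only dicts
        tqdm_dict = {'train_loss': loss}
        return {'loss': loss, 'progress_bar': tqdm_dict, 'log': tqdm_dict}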
@ -188,11 +192,11 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
|
||||||
return output
|
return output
|
||||||
|
|
||||||
If you define multiple optimizers, this step will be called with an additional
|
If you define multiple optimizers, this step will be called with an additional
|
||||||
`optimizer_idx` param.
|
``optimizer_idx`` parameter.
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
# Multiple optimizers (ie: GANs)
|
# Multiple optimizers (e.g.: GANs)
|
||||||
def training_step(self, batch, batch_idx, optimizer_idx):
|
def training_step(self, batch, batch_idx, optimizer_idx):
|
||||||
if optimizer_idx == 0:
|
if optimizer_idx == 0:
|
||||||
# do training_step with encoder
|
# do training_step with encoder
|
||||||
|
@ -207,7 +211,7 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
|
||||||
|
|
||||||
# Truncated back-propagation through time
|
# Truncated back-propagation through time
|
||||||
def training_step(self, batch, batch_idx, hiddens):
|
def training_step(self, batch, batch_idx, hiddens):
|
||||||
# hiddens are the hiddens from the previous truncated backprop step
|
# hiddens are the hidden states from the previous truncated backprop step
|
||||||
...
|
...
|
||||||
out, hiddens = self.lstm(data, hiddens)
|
out, hiddens = self.lstm(data, hiddens)
|
||||||
...
|
...
|
||||||
|
@ -221,27 +225,26 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
|
||||||
if you want to break out of the current training epoch early.
|
if you want to break out of the current training epoch early.
|
||||||
|
|
||||||
Notes:
|
Notes:
|
||||||
The presented loss value in progress bar is smooth (average) over last values,
|
The loss value shown in the progress bar is smoothed (averaged) over the last values,
|
||||||
so it differs from values set in train/validation step.
|
so it differs from the actual loss returned in train/validation step.
|
||||||
"""
|
"""
|
||||||
warnings.warn('`training_step` must be implemented to be used with the Lightning Trainer')
|
warnings.warn('`training_step` must be implemented to be used with the Lightning Trainer')
|
||||||
|
|
||||||
def training_end(self, *args, **kwargs):
|
def training_end(self, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
Warnings:
|
Warnings:
|
||||||
Deprecated in v0.7.0. use training_step_end instead
|
Deprecated in v0.7.0. Use :meth:`training_step_end` instead.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def training_epoch_end(
|
def training_epoch_end(
|
||||||
self,
|
self,
|
||||||
outputs: Union[List[Dict[str, Tensor]], List[List[Dict[str, Tensor]]]]
|
outputs: Union[List[Dict[str, Tensor]], List[List[Dict[str, Tensor]]]]
|
||||||
) -> Dict[str, Dict[str, Tensor]]:
|
) -> Dict[str, Dict[str, Tensor]]:
|
||||||
"""Called at the end of training epoch with the outputs of all training_steps
|
"""Called at the end of the training epoch with the outputs of all training steps.
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
# the pseudocode for these calls
|
# the pseudocode for these calls
|
||||||
|
|
||||||
train_outs = []
|
train_outs = []
|
||||||
for train_batch in train_data:
|
for train_batch in train_data:
|
||||||
out = training_step(train_batch)
|
out = training_step(train_batch)
|
||||||
|
@ -249,24 +252,25 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
|
||||||
training_epoch_end(train_outs)
|
training_epoch_end(train_outs)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
outputs: List of outputs you defined in training_step, or if there are multiple
|
outputs: List of outputs you defined in :meth:`training_step`, or if there are
|
||||||
dataloaders, a list containing a list of outputs for each dataloader
|
multiple dataloaders, a list containing a list of outputs for each dataloader.
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
Dict or OrderedDict
|
Dict or OrderedDict.
|
||||||
May contain the following optional keys:
|
May contain the following optional keys:
|
||||||
|
|
||||||
- log (metrics to be added to the logger ; only tensors)
|
- log (metrics to be added to the logger; only tensors)
|
||||||
- any metric used in a callback (e.g. early stopping).
|
- any metric used in a callback (e.g. early stopping).
|
||||||
|
|
||||||
.. note:: If this method is not overridden, this won't be called.
|
Note:
|
||||||
|
If this method is not overridden, this won't be called.
|
||||||
|
|
||||||
- The outputs here are strictly for logging or progress bar.
|
- The outputs here are strictly for logging or progress bar.
|
||||||
- If you don't need to display anything, don't return anything.
|
- If you don't need to display anything, don't return anything.
|
||||||
- If you want to manually set current step, you can specify the 'step' key in the 'log' Dict
|
- If you want to manually set current step, you can specify the 'step' key in the 'log' dict.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
With a single dataloader
|
With a single dataloader:
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
|
@ -283,7 +287,7 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
|
||||||
}
|
}
|
||||||
return results
|
return results
|
||||||
|
|
||||||
With multiple dataloaders, `outputs` will be a list of lists. The outer list contains
|
With multiple dataloaders, ``outputs`` will be a list of lists. The outer list contains
|
||||||
one entry per dataloader, while the inner list contains the individual outputs of
|
one entry per dataloader, while the inner list contains the individual outputs of
|
||||||
each training step for that dataloader.
|
each training step for that dataloader.
|
||||||
|
|
||||||
|
@ -310,11 +314,12 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
|
||||||
str, Union[Tensor, Dict[str, Tensor]]
|
str, Union[Tensor, Dict[str, Tensor]]
|
||||||
]:
|
]:
|
||||||
"""
|
"""
|
||||||
Use this when training with dp or ddp2 because training_step will operate
|
Use this when training with dp or ddp2 because :meth:`training_step`
|
||||||
on only part of the batch. However, this is still optional
|
will operate on only part of the batch. However, this is still optional
|
||||||
and only needed for things like softmax or NCE loss.
|
and only needed for things like softmax or NCE loss.
|
||||||
|
|
||||||
.. note:: If you later switch to ddp or some other mode, this will still be called
|
Note:
|
||||||
|
If you later switch to ddp or some other mode, this will still be called
|
||||||
so that you don't have to change your code
|
so that you don't have to change your code
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
@ -328,13 +333,12 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
|
||||||
batch_parts_outputs: What you return in `training_step` for each batch part.
|
batch_parts_outputs: What you return in `training_step` for each batch part.
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
dictionary with loss key and optional log, progress keys:
|
Dict with loss key and optional log or progress bar keys.
|
||||||
- loss -> tensor scalar [REQUIRED]
|
|
||||||
|
- loss -> tensor scalar **REQUIRED**
|
||||||
- progress_bar -> Dict for progress bar display. Must have only tensors
|
- progress_bar -> Dict for progress bar display. Must have only tensors
|
||||||
- log -> Dict of metrics to add to logger. Must have only tensors (no images, etc)
|
- log -> Dict of metrics to add to logger. Must have only tensors (no images, etc)
|
||||||
|
|
||||||
In this case you should define training_step_end to perform those calculations.
|
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
|
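In rough outline, the dp/ddp2 pattern described above follows this shape (a sketch; the output keys and the nce_loss helper are assumptions that mirror the partly elided example below):

.. code-block:: python

    def training_step(self, batch, batch_idx):
        # under dp/ddp2 this sees only one slice of the batch
        x, y = batch
        out = self(x)
        return {'out': out, 'y': y}

    def training_step_end(self, outputs):
        # `outputs` gathers the per-GPU results of training_step
        loss = nce_loss(outputs['out'], outputs['y'])  # nce_loss: assumed helper
        return {'loss': loss}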
@ -366,13 +370,13 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
|
||||||
loss = nce_loss(loss)
|
loss = nce_loss(loss)
|
||||||
return {'loss': loss}
|
return {'loss': loss}
|
||||||
|
|
||||||
.. seealso::
|
See Also:
|
||||||
see the :ref:`multi-gpu-training` guide for more details.
|
See the :ref:`multi-gpu-training` guide for more details.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def validation_step(self, *args, **kwargs) -> Dict[str, Tensor]:
|
def validation_step(self, *args, **kwargs) -> Dict[str, Tensor]:
|
||||||
r"""
|
r"""
|
||||||
Operate on a single batch of data from the validation set
|
Operates on a single batch of data from the validation set.
|
||||||
In this step you'd might generate examples or calculate anything of interest like accuracy.
|
In this step you might generate examples or calculate anything of interest like accuracy.
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
@ -385,15 +389,15 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
|
||||||
validation_epoch_end(val_outs)
|
validation_epoch_end(val_outs)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
batch (torch.nn.Tensor | (Tensor, Tensor) | [Tensor, Tensor]): The output of your
|
batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
|
||||||
dataloader. A tensor, tuple or list
|
The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
|
||||||
batch_idx (int): The index of this batch
|
batch_idx (int): The index of this batch
|
||||||
dataloader_idx (int): The index of the dataloader that produced this batch
|
dataloader_idx (int): The index of the dataloader that produced this batch
|
||||||
(only if multiple val datasets used)
|
(only if multiple val datasets used)
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
Dict or OrderedDict - passed to validation_epoch_end.
|
Dict or OrderedDict - passed to :meth:`validation_epoch_end`.
|
||||||
If you defined validation_step_end it will go to that first.
|
If you defined :meth:`validation_step_end` it will go to that first.
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
|
@ -434,7 +438,7 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
|
||||||
val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
|
val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
|
||||||
|
|
||||||
# all optional...
|
# all optional...
|
||||||
# return whatever you need for the collation function validation_end
|
# return whatever you need for the collation function validation_epoch_end
|
||||||
output = OrderedDict({
|
output = OrderedDict({
|
||||||
'val_loss': loss_val,
|
'val_loss': loss_val,
|
||||||
'val_acc': torch.tensor(val_acc), # everything must be a tensor
|
'val_acc': torch.tensor(val_acc), # everything must be a tensor
|
||||||
|
@ -451,21 +455,24 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
|
||||||
def validation_step(self, batch, batch_idx, dataset_idx):
|
def validation_step(self, batch, batch_idx, dataset_idx):
|
||||||
# dataset_idx tells you which dataset this is.
|
# dataset_idx tells you which dataset this is.
|
||||||
|
|
||||||
.. note:: If you don't need to validate you don't need to implement this method.
|
Note:
|
||||||
|
If you don't need to validate you don't need to implement this method.
|
||||||
|
|
||||||
.. note:: When the validation_step is called, the model has been put in eval mode
|
Note:
|
||||||
|
When the :meth:`validation_step` is called, the model has been put in eval mode
|
||||||
and PyTorch gradients have been disabled. At the end of validation,
|
and PyTorch gradients have been disabled. At the end of validation,
|
||||||
the model goes back to training mode and gradients are enabled.
|
the model goes back to training mode and gradients are enabled.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def validation_step_end(self, *args, **kwargs) -> Dict[str, Tensor]:
|
def validation_step_end(self, *args, **kwargs) -> Dict[str, Tensor]:
|
||||||
"""
|
"""
|
||||||
Use this when validating with dp or ddp2 because validation_step will operate
|
Use this when validating with dp or ddp2 because :meth:`validation_step`
|
||||||
on only part of the batch. However, this is still optional
|
will operate on only part of the batch. However, this is still optional
|
||||||
and only needed for things like softmax or NCE loss.
|
and only needed for things like softmax or NCE loss.
|
||||||
|
|
||||||
.. note:: If you later switch to ddp or some other mode, this will still be called
|
Note:
|
||||||
so that you don't have to change your code
|
If you later switch to ddp or some other mode, this will still be called
|
||||||
|
so that you don't have to change your code.
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
|
@ -475,12 +482,11 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
|
||||||
validation_step_end(batch_parts_outputs)
|
validation_step_end(batch_parts_outputs)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
batch_parts_outputs: What you return in `validation_step` for each batch part.
|
batch_parts_outputs: What you return in :meth:`validation_step`
|
||||||
|
for each batch part.
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
Dict or OrderedDict - passed to the validation_epoch_end
|
Dict or OrderedDict - passed to the :meth:`validation_epoch_end` method.
|
||||||
|
|
||||||
In this case you should define validation_step_end to perform those calculations.
|
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
@ -513,14 +519,15 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
|
||||||
loss = nce_loss(loss)
|
loss = nce_loss(loss)
|
||||||
return {'loss': loss}
|
return {'loss': loss}
|
||||||
|
|
||||||
.. seealso::
|
See Also:
|
||||||
see the :ref:`multi-gpu-training` guide for more details.
|
See the :ref:`multi-gpu-training` guide for more details.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def validation_end(self, outputs):
|
def validation_end(self, outputs):
|
||||||
"""
|
"""
|
||||||
Warnings:
|
Warnings:
|
||||||
Deprecated in v0.7.0. use validation_epoch_end instead. Will be removed 1.0.0
|
Deprecated in v0.7.0. Use :meth:`validation_epoch_end` instead.
|
||||||
|
Will be removed in 1.0.0.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def validation_epoch_end(
|
def validation_epoch_end(
|
||||||
|
@ -528,12 +535,11 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
|
||||||
outputs: Union[List[Dict[str, Tensor]], List[List[Dict[str, Tensor]]]]
|
outputs: Union[List[Dict[str, Tensor]], List[List[Dict[str, Tensor]]]]
|
||||||
) -> Dict[str, Dict[str, Tensor]]:
|
) -> Dict[str, Dict[str, Tensor]]:
|
||||||
"""
|
"""
|
||||||
Called at end of validation epoch with the outputs of all validation_steps
|
Called at the end of the validation epoch with the outputs of all validation steps.
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
# the pseudocode for these calls
|
# the pseudocode for these calls
|
||||||
|
|
||||||
val_outs = []
|
val_outs = []
|
||||||
for val_batch in val_data:
|
for val_batch in val_data:
|
||||||
out = validation_step(train_batch)
|
out = validation_step(val_batch)
|
||||||
|
@ -541,24 +547,25 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
|
||||||
validation_epoch_end(val_outs)
|
validation_epoch_end(val_outs)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
outputs: List of outputs you defined in validation_step, or if there are multiple
|
outputs: List of outputs you defined in :meth:`validation_step`, or if there
|
||||||
dataloaders, a list containing a list of outputs for each dataloader
|
are multiple dataloaders, a list containing a list of outputs for each dataloader.
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
Dict or OrderedDict
|
Dict or OrderedDict.
|
||||||
May have the following optional keys:
|
May have the following optional keys:
|
||||||
|
|
||||||
- progress_bar (dict for progress bar display ; only tensors)
|
- progress_bar (dict for progress bar display; only tensors)
|
||||||
- log (dict of metrics to add to logger ; only tensors).
|
- log (dict of metrics to add to logger; only tensors).
|
||||||
|
|
||||||
.. note:: If you didn't define a validation_step, this won't be called.
|
Note:
|
||||||
|
If you didn't define a :meth:`validation_step`, this won't be called.
|
||||||
|
|
||||||
- The outputs here are strictly for logging or progress bar.
|
- The outputs here are strictly for logging or progress bar.
|
||||||
- If you don't need to display anything, don't return anything.
|
- If you don't need to display anything, don't return anything.
|
||||||
- If you want to manually set current step, you can specify the 'step' key in the 'log' Dict
|
- If you want to manually set current step, you can specify the 'step' key in the 'log' dict.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
With a single dataloader
|
With a single dataloader:
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
|
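A minimal aggregation matching the description above (the metric names are illustrative):

.. code-block:: python

    def validation_epoch_end(self, outputs):
        # average the per-batch metric returned by validation_step
        val_loss_mean = torch.stack([x['val_loss'] for x in outputs]).mean()

        # an optional 'step' entry inside 'log' sets the logger step manually
        logs = {'val_loss': val_loss_mean, 'step': self.current_epoch}
        return {'val_loss': val_loss_mean, 'log': logs}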
@ -604,29 +611,29 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
|
||||||
|
|
||||||
def test_step(self, *args, **kwargs) -> Dict[str, Tensor]:
|
def test_step(self, *args, **kwargs) -> Dict[str, Tensor]:
|
||||||
r"""
|
r"""
|
||||||
Operate on a single batch of data from the test set
|
Operates on a single batch of data from the test set.
|
||||||
In this step you'd normally generate examples or calculate anything of interest
|
In this step you'd normally generate examples or calculate anything of interest
|
||||||
such as accuracy.
|
such as accuracy.
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
# the pseudocode for these calls
|
# the pseudocode for these calls
|
||||||
|
|
||||||
test_outs = []
|
test_outs = []
|
||||||
for test_batch in test_data:
|
for test_batch in test_data:
|
||||||
out = test_step(train_batch)
|
out = test_step(test_batch)
|
||||||
test_outs.append(out)
|
test_outs.append(out)
|
||||||
test_epoch_end(test_outs)
|
test_epoch_end(test_outs)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
batch (torch.nn.Tensor | (Tensor, Tensor) | [Tensor, Tensor]): The output of your
|
batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
|
||||||
dataloader. A tensor, tuple or list
|
The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
|
||||||
batch_idx (int): The index of this batch
|
batch_idx (int): The index of this batch.
|
||||||
dataloader_idx (int): The index of the dataloader that produced this batch
|
dataloader_idx (int): The index of the dataloader that produced this batch
|
||||||
(only if multiple test datasets used)
|
(only if multiple test datasets used).
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
Dict or OrderedDict - passed to the test_step_end
|
Dict or OrderedDict - passed to the :meth:`test_epoch_end` method.
|
||||||
|
If you defined :meth:`test_step_end` it will go to that first.
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
|
@ -658,7 +665,7 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
|
||||||
val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
|
val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
|
||||||
|
|
||||||
# all optional...
|
# all optional...
|
||||||
# return whatever you need for the collation function validation_end
|
# return whatever you need for the collation function test_epoch_end
|
||||||
output = OrderedDict({
|
output = OrderedDict({
|
||||||
'val_loss': loss_val,
|
'val_loss': loss_val,
|
||||||
'val_acc': torch.tensor(val_acc), # everything must be a tensor
|
'val_acc': torch.tensor(val_acc), # everything must be a tensor
|
||||||
|
@ -667,30 +674,33 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
|
||||||
# return an optional dict
|
# return an optional dict
|
||||||
return output
|
return output
|
||||||
|
|
||||||
If you pass in multiple validation datasets, validation_step will have an additional
|
If you pass in multiple test datasets, :meth:`test_step` will have an additional
|
||||||
argument.
|
argument.
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
# CASE 2: multiple validation datasets
|
# CASE 2: multiple test datasets
|
||||||
def test_step(self, batch, batch_idx, dataset_idx):
|
def test_step(self, batch, batch_idx, dataset_idx):
|
||||||
# dataset_idx tells you which dataset this is.
|
# dataset_idx tells you which dataset this is.
|
||||||
|
|
||||||
.. note:: If you don't need to validate you don't need to implement this method.
|
Note:
|
||||||
|
If you don't need to test you don't need to implement this method.
|
||||||
|
|
||||||
.. note:: When the test_step is called, the model has been put in eval mode and
|
Note:
|
||||||
|
When the :meth:`test_step` is called, the model has been put in eval mode and
|
||||||
PyTorch gradients have been disabled. At the end of the test epoch, the model goes back
|
PyTorch gradients have been disabled. At the end of the test epoch, the model goes back
|
||||||
to training mode and gradients are enabled.
|
to training mode and gradients are enabled.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def test_step_end(self, *args, **kwargs) -> Dict[str, Tensor]:
|
def test_step_end(self, *args, **kwargs) -> Dict[str, Tensor]:
|
||||||
"""
|
"""
|
||||||
Use this when testing with dp or ddp2 because test_step will operate
|
Use this when testing with dp or ddp2 because :meth:`test_step` will operate
|
||||||
on only part of the batch. However, this is still optional
|
on only part of the batch. However, this is still optional
|
||||||
and only needed for things like softmax or NCE loss.
|
and only needed for things like softmax or NCE loss.
|
||||||
|
|
||||||
.. note:: If you later switch to ddp or some other mode, this will still be called
|
Note:
|
||||||
so that you don't have to change your code
|
If you later switch to ddp or some other mode, this will still be called
|
||||||
|
so that you don't have to change your code.
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
|
@ -700,12 +710,10 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
|
||||||
test_step_end(batch_parts_outputs)
|
test_step_end(batch_parts_outputs)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
batch_parts_outputs: What you return in `training_step` for each batch part.
|
batch_parts_outputs: What you return in :meth:`test_step` for each batch part.
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
Dict or OrderedDict - passed to the test_epoch_end
|
Dict or OrderedDict - passed to the :meth:`test_epoch_end` method.
|
||||||
|
|
||||||
In this case you should define test_step_end to perform those calculations.
|
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
@ -738,14 +746,15 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
|
||||||
loss = nce_loss(loss)
|
loss = nce_loss(loss)
|
||||||
return {'loss': loss}
|
return {'loss': loss}
|
||||||
|
|
||||||
.. seealso::
|
See Also:
|
||||||
see the :ref:`multi-gpu-training` guide for more details.
|
See the :ref:`multi-gpu-training` guide for more details.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def test_end(self, outputs):
|
def test_end(self, outputs):
|
||||||
"""
|
"""
|
||||||
Warnings:
|
Warnings:
|
||||||
Deprecated in v0.7.0. use test_epoch_end instead. Will be removed 1.0.0
|
Deprecated in v0.7.0. Use :meth:`test_epoch_end` instead.
|
||||||
|
Will be removed in 1.0.0.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def test_epoch_end(
|
def test_epoch_end(
|
||||||
|
@ -753,12 +762,11 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
|
||||||
outputs: Union[List[Dict[str, Tensor]], List[List[Dict[str, Tensor]]]]
|
outputs: Union[List[Dict[str, Tensor]], List[List[Dict[str, Tensor]]]]
|
||||||
) -> Dict[str, Dict[str, Tensor]]:
|
) -> Dict[str, Dict[str, Tensor]]:
|
||||||
"""
|
"""
|
||||||
Called at end of test epoch with the output of all test_steps.
|
Called at the end of a test epoch with the output of all test steps.
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
# the pseudocode for these calls
|
# the pseudocode for these calls
|
||||||
|
|
||||||
test_outs = []
|
test_outs = []
|
||||||
for test_batch in test_data:
|
for test_batch in test_data:
|
||||||
out = test_step(test_batch)
|
out = test_step(test_batch)
|
||||||
|
@ -766,22 +774,24 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
|
||||||
test_epoch_end(test_outs)
|
test_epoch_end(test_outs)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
outputs: List of outputs you defined in test_step, or if there are multiple
|
outputs: List of outputs you defined in :meth:`test_step_end`, or if there
|
||||||
dataloaders, a list containing a list of outputs for each dataloader
|
are multiple dataloaders, a list containing a list of outputs for each dataloader.
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
Dict or OrderedDict (dict): Dict has the following optional keys:
|
Dict or OrderedDict: Dict has the following optional keys:
|
||||||
progress_bar -> Dict for progress bar display. Must have only tensors
|
|
||||||
log -> Dict of metrics to add to logger. Must have only tensors (no images, etc)
|
|
||||||
|
|
||||||
.. note:: If you didn't define a test_step, this won't be called.
|
- progress_bar -> Dict for progress bar display. Must have only tensors.
|
||||||
|
- log -> Dict of metrics to add to logger. Must have only tensors (no images, etc).
|
||||||
|
|
||||||
|
Note:
|
||||||
|
If you didn't define a :meth:`test_step`, this won't be called.
|
||||||
|
|
||||||
- The outputs here are strictly for logging or progress bar.
|
- The outputs here are strictly for logging or progress bar.
|
||||||
- If you don't need to display anything, don't return anything.
|
- If you don't need to display anything, don't return anything.
|
||||||
- If you want to manually set current step, specify it with the 'step' key in the 'log' Dict
|
- If you want to manually set current step, specify it with the 'step' key in the 'log' dict.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
With a single dataloader
|
With a single dataloader:
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
|
@ -834,13 +844,13 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
|
||||||
Override to init DDP in your own way or with your own wrapper.
|
Override to init DDP in your own way or with your own wrapper.
|
||||||
The only requirements are that:
|
The only requirements are that:
|
||||||
|
|
||||||
1. On a validation batch the call goes to model.validation_step.
|
1. On a validation batch the call goes to ``model.validation_step``.
|
||||||
2. On a training batch the call goes to model.training_step.
|
2. On a training batch the call goes to ``model.training_step``.
|
||||||
3. On a testing batch, the call goes to model.test_step
|
3. On a testing batch, the call goes to ``model.test_step``.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model: the LightningModule currently being optimized
|
model: the :class:`LightningModule` currently being optimized.
|
||||||
device_ids: the list of GPU ids
|
device_ids: the list of GPU ids.
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
DDP wrapped model
|
DDP wrapped model
|
||||||
|
@ -868,14 +878,13 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
|
||||||
|
|
||||||
def init_ddp_connection(self, proc_rank: int, world_size: int) -> None:
|
def init_ddp_connection(self, proc_rank: int, world_size: int) -> None:
|
||||||
r"""
|
r"""
|
||||||
|
|
||||||
Override to define your custom way of setting up a distributed environment.
|
Override to define your custom way of setting up a distributed environment.
|
||||||
|
|
||||||
Lightning's implementation uses env:// init by default and sets the first node as root.
|
Lightning's implementation uses ``env://`` init by default and sets the first node as root.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
proc_rank: The current process rank within the node.
|
proc_rank: The current process rank within the node.
|
||||||
world_size: Number of GPUs being use across all nodes. (num_nodes * num_gpus).
|
world_size: Number of GPUs being used across all nodes (num_nodes * num_gpus).
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
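A custom override mirroring the default behaviour can be as small as this (the backend choice is an assumption, not part of the diff):

.. code-block:: python

    def init_ddp_connection(self, proc_rank, world_size):
        # mirrors the default env:// init; MASTER_ADDR/MASTER_PORT must be set
        torch.distributed.init_process_group(
            backend='nccl', rank=proc_rank, world_size=world_size)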
@ -952,13 +961,13 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
|
||||||
amp_level: str
|
amp_level: str
|
||||||
) -> Tuple['LightningModule', List[Optimizer]]:
|
) -> Tuple['LightningModule', List[Optimizer]]:
|
||||||
r"""
|
r"""
|
||||||
Override to init AMP your own way
|
Override to init AMP your own way.
|
||||||
Must return a model and list of optimizers
|
Must return a model and list of optimizers.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
amp: pointer to amp library object
|
amp: pointer to amp library object.
|
||||||
model: pointer to current lightningModule
|
model: pointer to current :class:`LightningModule`.
|
||||||
optimizers: list of optimizers passed in configure_optimizers()
|
optimizers: list of optimizers passed in :meth:`configure_optimizers`.
|
||||||
amp_level: AMP mode chosen ('O1', 'O2', etc...)
|
amp_level: AMP mode chosen ('O1', 'O2', etc...)
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
|
@ -988,18 +997,20 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
|
||||||
Choose what optimizers and learning-rate schedulers to use in your optimization.
|
Choose what optimizers and learning-rate schedulers to use in your optimization.
|
||||||
Normally you'd need one. But in the case of GANs or similar you might have multiple.
|
Normally you'd need one. But in the case of GANs or similar you might have multiple.
|
||||||
|
|
||||||
Return: any of these 6 options:
|
Return:
|
||||||
|
Any of these 6 options.
|
||||||
|
|
||||||
- Single optimizer.
|
- Single optimizer.
|
||||||
- List or Tuple - List of optimizers.
|
- List or Tuple - List of optimizers.
|
||||||
- Two lists - The first list has multiple optimizers, the second a list of LR schedulers.
|
- Two lists - The first list has multiple optimizers, the second a list of LR schedulers.
|
||||||
- Dictionary, with an `optimizer` key and (optionally) a `lr_scheduler` key.
|
- Dictionary, with an 'optimizer' key and (optionally) a 'lr_scheduler' key.
|
||||||
- Tuple of dictionaries as described, with an optional `frequency` key.
|
- Tuple of dictionaries as described, with an optional 'frequency' key.
|
||||||
- None - Fit will run without any optimizer.
|
- None - Fit will run without any optimizer.
|
||||||
|
|
||||||
Note:
|
Note:
|
||||||
The `frequency` value is an int corresponding to the number of sequential batches
|
The 'frequency' value is an int corresponding to the number of sequential batches
|
||||||
optimized with the specific optimizer. It should be given to none or to all of the optimizers.
|
optimized with the specific optimizer. It should be given to none or to all of the optimizers.
|
||||||
There is difference between passing multiple optimizers in a list,
|
There is a difference between passing multiple optimizers in a list,
|
||||||
and passing multiple optimizers in dictionaries with a frequency of 1:
|
and passing multiple optimizers in dictionaries with a frequency of 1:
|
||||||
In the former case, all optimizers will operate on the given batch in each optimization step.
|
In the former case, all optimizers will operate on the given batch in each optimization step.
|
||||||
In the latter, only one optimizer will operate on the given batch at every step.
|
In the latter, only one optimizer will operate on the given batch at every step.
|
||||||
|
@ -1012,20 +1023,20 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
|
||||||
opt = Adam(self.parameters(), lr=1e-3)
|
opt = Adam(self.parameters(), lr=1e-3)
|
||||||
return opt
|
return opt
|
||||||
|
|
||||||
# multiple optimizer case (eg: GAN)
|
# multiple optimizer case (e.g.: GAN)
|
||||||
def configure_optimizers(self):
|
def configure_optimizers(self):
|
||||||
generator_opt = Adam(self.model_gen.parameters(), lr=0.01)
|
generator_opt = Adam(self.model_gen.parameters(), lr=0.01)
|
||||||
disriminator_opt = Adam(self.model_disc.parameters(), lr=0.02)
|
discriminator_opt = Adam(self.model_disc.parameters(), lr=0.02)
|
||||||
return generator_opt, disriminator_opt
|
return generator_opt, discriminator_opt
|
||||||
|
|
||||||
# example with learning_rate schedulers
|
# example with learning rate schedulers
|
||||||
def configure_optimizers(self):
|
def configure_optimizers(self):
|
||||||
generator_opt = Adam(self.model_gen.parameters(), lr=0.01)
|
generator_opt = Adam(self.model_gen.parameters(), lr=0.01)
|
||||||
disriminator_opt = Adam(self.model_disc.parameters(), lr=0.02)
|
discriminator_opt = Adam(self.model_disc.parameters(), lr=0.02)
|
||||||
discriminator_sched = CosineAnnealing(discriminator_opt, T_max=10)
|
discriminator_sched = CosineAnnealing(discriminator_opt, T_max=10)
|
||||||
return [generator_opt, disriminator_opt], [discriminator_sched]
|
return [generator_opt, discriminator_opt], [discriminator_sched]
|
||||||
|
|
||||||
# example with step-based learning_rate schedulers
|
# example with step-based learning rate schedulers
|
||||||
def configure_optimizers(self):
|
def configure_optimizers(self):
|
||||||
gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
|
gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
|
||||||
dis_opt = Adam(self.model_disc.parameters(), lr=0.02)
|
dis_opt = Adam(self.model_disc.parameters(), lr=0.02)
|
||||||
|
@ -1056,18 +1067,18 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
|
||||||
- If you use 16-bit precision (``precision=16``), Lightning will automatically
|
- If you use 16-bit precision (``precision=16``), Lightning will automatically
|
||||||
handle the optimizers for you.
|
handle the optimizers for you.
|
||||||
|
|
||||||
- If you use multiple optimizers, training_step will have an additional
|
- If you use multiple optimizers, :meth:`training_step` will have an additional
|
||||||
``optimizer_idx`` parameter.
|
``optimizer_idx`` parameter.
|
||||||
|
|
||||||
- If you use LBFGS lightning handles the closure function automatically for you
|
- If you use LBFGS Lightning handles the closure function automatically for you.
|
||||||
|
|
||||||
- If you use multiple optimizers, gradients will be calculated only
|
- If you use multiple optimizers, gradients will be calculated only
|
||||||
for the parameters of current optimizer at each training step.
|
for the parameters of current optimizer at each training step.
|
||||||
|
|
||||||
- If you need to control how often those optimizers step or override the
|
- If you need to control how often those optimizers step or override the
|
||||||
default .step() schedule, override the `optimizer_step` hook.
|
default ``.step()`` schedule, override the :meth:`optimizer_step` hook.
|
||||||
|
|
||||||
- If you only want to call a learning rate scheduler every `x` step or epoch,
|
- If you only want to call a learning rate scheduler every ``x`` step or epoch,
|
||||||
or want to monitor a custom metric, you can specify these in a dictionary:
|
or want to monitor a custom metric, you can specify these in a dictionary:
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
@ -1091,16 +1102,16 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
|
||||||
second_order_closure: Optional[Callable] = None,
|
second_order_closure: Optional[Callable] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
r"""
|
r"""
|
||||||
|
Override this method to adjust the default way the
|
||||||
Override this method to adjust the default way the Trainer calls each optimizer.
|
:class:`~pytorch_lightning.trainer.trainer.Trainer` calls each optimizer.
|
||||||
By default, Lightning calls .step() and zero_grad() as shown in the example
|
By default, Lightning calls ``step()`` and ``zero_grad()`` as shown in the example
|
||||||
once per optimizer.
|
once per optimizer.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
epoch: Current epoch
|
epoch: Current epoch
|
||||||
batch_idx: Index of current batch
|
batch_idx: Index of current batch
|
||||||
optimizer: A PyTorch optimizer
|
optimizer: A PyTorch optimizer
|
||||||
optimizer_idx: If you used multiple optimizers this indexes into that list
|
optimizer_idx: If you used multiple optimizers this indexes into that list.
|
||||||
second_order_closure: closure for second order methods
|
second_order_closure: closure for second order methods
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
|
@ -1112,7 +1123,7 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
|
||||||
# Alternating schedule for optimizer steps (ie: GANs)
|
# Alternating schedule for optimizer steps (i.e.: GANs)
|
||||||
def optimizer_step(self, current_epoch, batch_idx, optimizer, optimizer_idx,
|
def optimizer_step(self, current_epoch, batch_idx, optimizer, optimizer_idx,
|
||||||
second_order_closure=None):
|
second_order_closure=None):
|
||||||
# update generator opt every 2 steps
|
# update generator opt every 2 steps
|
||||||
|
@ -1132,7 +1143,7 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
|
||||||
|
|
||||||
|
|
||||||
Here's another example showing how to use this for more advanced things such as
|
Here's another example showing how to use this for more advanced things such as
|
||||||
learning-rate warm-up:
|
learning rate warm-up:
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
|
@ -1162,17 +1173,16 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
|
||||||
|
|
||||||
def tbptt_split_batch(self, batch: Tensor, split_size: int) -> list:
|
def tbptt_split_batch(self, batch: Tensor, split_size: int) -> list:
|
||||||
r"""
|
r"""
|
||||||
|
|
||||||
When using truncated backpropagation through time, each batch must be split along the
|
When using truncated backpropagation through time, each batch must be split along the
|
||||||
time dimension. Lightning handles this by default, but for custom behavior override
|
time dimension. Lightning handles this by default, but for custom behavior override
|
||||||
this function.
|
this function.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
batch: Current batch
|
batch: Current batch
|
||||||
split_size: How big the split is
|
split_size: The size of the split
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
list of batch splits. Each split will be passed to forward_step to enable truncated
|
List of batch splits. Each split will be passed to :meth:`training_step` to enable truncated
|
||||||
back propagation through time. The default implementation splits root level Tensors and
|
back propagation through time. The default implementation splits root level Tensors and
|
||||||
Sequences at dim=1 (i.e. time dim). It assumes that each time dim is the same length.
|
Sequences at dim=1 (i.e. time dim). It assumes that each time dim is the same length.
|
||||||
|
|
||||||
|
@ -1197,8 +1207,11 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
|
||||||
|
|
||||||
return splits
|
return splits
|
||||||
|
|
||||||
.. note:: Called in the training loop after on_batch_start if ``truncated_bptt_steps > 0``.
|
Note:
|
||||||
Each returned batch split is passed separately to ``training_step(...)``.
|
Called in the training loop after
|
||||||
|
:meth:`~pytorch_lightning.callbacks.base.Callback.on_batch_start`
|
||||||
|
if :paramref:`~pytorch_lightning.trainer.Trainer.truncated_bptt_steps` > 0.
|
||||||
|
Each returned batch split is passed separately to :meth:`training_step`.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
time_dims = [len(x[0]) for x in batch if isinstance(x, (torch.Tensor, collections.Sequence))]
|
time_dims = [len(x[0]) for x in batch if isinstance(x, (torch.Tensor, collections.Sequence))]
|
||||||
|
@ -1223,13 +1236,10 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
|
||||||
return splits
|
return splits
|
||||||
|
|
||||||
def prepare_data(self) -> None:
|
def prepare_data(self) -> None:
|
||||||
"""Use this to download and prepare data.
|
"""
|
||||||
In distributed (GPU, TPU), this will only be called once
|
Use this to download and prepare data.
|
||||||
|
In distributed (GPU, TPU), this will only be called once.
|
||||||
Return:
|
This is called before requesting the dataloaders:
|
||||||
PyTorch DataLoader
|
|
||||||
|
|
||||||
This is called before requesting the dataloaders
|
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
|
@ -1248,23 +1258,25 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def train_dataloader(self) -> DataLoader:
|
def train_dataloader(self) -> DataLoader:
|
||||||
"""Implement a PyTorch DataLoader
|
"""
|
||||||
|
Implement a PyTorch DataLoader for training.
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
PyTorch DataLoader
|
Single PyTorch :class:`~torch.utils.data.DataLoader`.
|
||||||
|
|
||||||
Return a dataloader. It will not be called every epoch unless you set
|
The dataloader you return will not be called every epoch unless you set
|
||||||
```Trainer(reload_dataloaders_every_epoch=True)```.
|
:paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_epoch` to ``True``.
|
||||||
|
|
||||||
It's recommended that all data downloads and preparation happen in prepare_data().
|
It's recommended that all data downloads and preparation happen in :meth:`prepare_data`.
|
||||||
|
|
||||||
.. note:: Lightning adds the correct sampler for distributed and arbitrary hardware.
|
- :meth:`~pytorch_lightning.trainer.Trainer.fit`
|
||||||
No need to set yourself.
|
|
||||||
|
|
||||||
- .fit()
|
|
||||||
- ...
|
- ...
|
||||||
- prepare_data()
|
- :meth:`prepare_data`
|
||||||
- train_dataloader
|
- :meth:`train_dataloader`
|
||||||
|
|
||||||
|
Note:
|
||||||
|
Lightning adds the correct sampler for distributed and arbitrary hardware.
|
||||||
|
There is no need to set it yourself.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
@ -1285,10 +1297,9 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
|
||||||
warnings.warn('`train_dataloader` must be implemented to be used with the Lightning Trainer')
|
warnings.warn('`train_dataloader` must be implemented to be used with the Lightning Trainer')
|
||||||
|
|
||||||
def tng_dataloader(self): # todo: remove in v1.0.0
|
def tng_dataloader(self): # todo: remove in v1.0.0
|
||||||
"""Implement a PyTorch DataLoader.
|
"""
|
||||||
|
|
||||||
Warnings:
|
Warnings:
|
||||||
Deprecated in v0.5.0. use train_dataloader instead. Will be removed 1.0.0
|
Deprecated in v0.5.0. Use :meth:`train_dataloader` instead. Will be removed in 1.0.0.
|
||||||
"""
|
"""
|
||||||
output = self.train_dataloader()
|
output = self.train_dataloader()
|
||||||
warnings.warn("`tng_dataloader` has been renamed to `train_dataloader` since v0.5.0."
|
warnings.warn("`tng_dataloader` has been renamed to `train_dataloader` since v0.5.0."
|
||||||
|
@ -1297,24 +1308,26 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
|
||||||
|
|
||||||
def test_dataloader(self) -> Union[DataLoader, List[DataLoader]]:
|
def test_dataloader(self) -> Union[DataLoader, List[DataLoader]]:
|
||||||
r"""
|
r"""
|
||||||
|
Implement one or multiple PyTorch DataLoaders for testing.
|
||||||
|
|
||||||
Return a dataloader. It will not be called every epoch unless you set
|
The dataloader you return will not be called every epoch unless you set
|
||||||
```Trainer(reload_dataloaders_every_epoch=True)```.
|
:paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_epoch` to ``True``.
|
||||||
|
|
||||||
It's recommended that all data downloads and preparation happen in prepare_data().
|
It's recommended that all data downloads and preparation happen in :meth:`prepare_data`.
|
||||||
|
|
||||||
- .fit()
|
- :meth:`~pytorch_lightning.trainer.Trainer.fit`
|
||||||
- ...
|
- ...
|
||||||
- prepare_data()
|
- :meth:`prepare_data`
|
||||||
- train_dataloader
|
- :meth:`train_dataloader`
|
||||||
- val_dataloader
|
- :meth:`val_dataloader`
|
||||||
- test_dataloader
|
- :meth:`test_dataloader`
|
||||||
|
|
||||||
.. note:: Lightning adds the correct sampler for distributed and arbitrary hardware.
|
Note:
|
||||||
No need to set yourself.
|
Lightning adds the correct sampler for distributed and arbitrary hardware.
|
||||||
|
There is no need to set it yourself.
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
Single or multiple PyTorch DataLoader
|
Single or multiple PyTorch DataLoaders.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
@ -1332,30 +1345,34 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
|
||||||
|
|
||||||
return loader
|
return loader
|
||||||
|
|
||||||
.. note:: If you don't need a test dataset and a test_step, you don't need to implement
|
Note:
|
||||||
|
If you don't need a test dataset and a :meth:`test_step`, you don't need to implement
|
||||||
this method.
|
this method.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def val_dataloader(self) -> Union[DataLoader, List[DataLoader]]:
|
def val_dataloader(self) -> Union[DataLoader, List[DataLoader]]:
|
||||||
r"""
|
r"""
|
||||||
|
Implement one or multiple PyTorch DataLoaders for validation.
|
||||||
|
|
||||||
Return a dataloader. It will not be called every epoch unless you set
|
The dataloader you return will not be called every epoch unless you set
|
||||||
```Trainer(reload_dataloaders_every_epoch=True)```.
|
:paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_epoch` to ``True``.
|
||||||
|
|
||||||
It's recommended that all data downloads and preparation happen in prepare_data().
|
It's recommended that all data downloads and preparation happen in :meth:`prepare_data`.
|
||||||
|
|
||||||
- .fit()
|
- :meth:`~pytorch_lightning.trainer.Trainer.fit`
|
||||||
- ...
|
- ...
|
||||||
- prepare_data()
|
- :meth:`prepare_data`
|
||||||
- train_dataloader
|
- :meth:`train_dataloader`
|
||||||
- val_dataloader
|
- :meth:`val_dataloader`
|
||||||
|
- :meth:`test_dataloader`
|
||||||
|
|
||||||
.. note:: Lightning adds the correct sampler for distributed and arbitrary hardware
|
Note:
|
||||||
No need to set yourself.
|
Lightning adds the correct sampler for distributed and arbitrary hardware
|
||||||
|
There is no need to set it yourself.
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
Single or multiple PyTorch DataLoader
|
Single or multiple PyTorch DataLoaders.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
@ -1377,37 +1394,20 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
|
||||||
def val_dataloader(self):
|
def val_dataloader(self):
|
||||||
return [loader_a, loader_b, ..., loader_n]
|
return [loader_a, loader_b, ..., loader_n]
|
||||||
|
|
||||||
.. code-block:: python
|
Note:
|
||||||
|
If you don't need a validation dataset and a :meth:`validation_step`, you don't need to
|
||||||
def val_dataloader(self):
|
|
||||||
transform = transforms.Compose([transforms.ToTensor(),
|
|
||||||
transforms.Normalize((0.5,), (1.0,))])
|
|
||||||
dataset = MNIST(root='/path/to/mnist/', train=False,
|
|
||||||
transform=transform, download=True)
|
|
||||||
loader = torch.utils.data.DataLoader(
|
|
||||||
dataset=dataset,
|
|
||||||
batch_size=self.hparams.batch_size,
|
|
||||||
shuffle=True
|
|
||||||
)
|
|
||||||
|
|
||||||
return loader
|
|
||||||
|
|
||||||
# can also return multiple dataloaders
|
|
||||||
def val_dataloader(self):
|
|
||||||
return [loader_a, loader_b, ..., loader_n]
|
|
||||||
|
|
||||||
.. note:: If you don't need a validation dataset and a validation_step, you don't need to
|
|
||||||
implement this method.
|
implement this method.
|
||||||
|
|
||||||
.. note:: In the case where you return multiple `val_dataloaders`, the `validation_step`
|
Note:
|
||||||
will have an argument `dataset_idx` which matches the order here.
|
In the case where you return multiple validation dataloaders, the :meth:`validation_step`
|
||||||
|
will have an argument ``dataset_idx`` which matches the order here.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load_from_metrics(cls, weights_path, tags_csv, map_location=None):
|
def load_from_metrics(cls, weights_path, tags_csv, map_location=None):
|
||||||
r"""
|
r"""
|
||||||
Warning:
|
Warning:
|
||||||
Deprecated in version 0.7.0. You should use `load_from_checkpoint` instead.
|
Deprecated in version 0.7.0. You should use :meth:`load_from_checkpoint` instead.
|
||||||
Will be removed in v0.9.0.
|
Will be removed in v0.9.0.
|
||||||
"""
|
"""
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
|
@ -1425,11 +1425,10 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
|
||||||
*args, **kwargs
|
*args, **kwargs
|
||||||
) -> 'LightningModule':
|
) -> 'LightningModule':
|
||||||
r"""
|
r"""
|
||||||
|
Primary way of loading a model from a checkpoint. When Lightning saves a checkpoint
|
||||||
Primary way of loading model from a checkpoint. When Lightning saves a checkpoint
|
it stores the hyperparameters in the checkpoint if you initialized your :class:`LightningModule`
|
||||||
it stores the hyperparameters in the checkpoint if you initialized your LightningModule
|
with an argument called ``hparams`` which is a :class:`~argparse.Namespace`
|
||||||
with an argument called `hparams` which is a Namespace (output of using argparse
|
(output of :meth:`~argparse.ArgumentParser.parse_args` when parsing command line arguments).
|
||||||
to parse command line arguments).
|
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
@ -1449,8 +1448,7 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
|
||||||
map_location:
|
map_location:
|
||||||
If your checkpoint saved a GPU model and you now load on CPUs
|
If your checkpoint saved a GPU model and you now load on CPUs
|
||||||
or a different number of GPUs, use this to map to the new setup.
|
or a different number of GPUs, use this to map to the new setup.
|
||||||
The behaviour is the same as in
|
The behaviour is the same as in :func:`torch.load`.
|
||||||
`torch.load <https://pytorch.org/docs/stable/torch.html#torch.load>`_.
|
|
||||||
tags_csv: Optional path to a .csv file with two columns (key, value)
|
tags_csv: Optional path to a .csv file with two columns (key, value)
|
||||||
as in this example::
|
as in this example::
|
||||||
|
|
||||||
|
@ -1462,11 +1460,11 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
|
||||||
to the checkpoint.
|
to the checkpoint.
|
||||||
However, if your checkpoint weights don't have the hyperparameters saved,
|
However, if your checkpoint weights don't have the hyperparameters saved,
|
||||||
use this method to pass in a .csv file with the hparams you'd like to use.
|
use this method to pass in a .csv file with the hparams you'd like to use.
|
||||||
These will be converted into a argparse.Namespace and passed into your
|
These will be converted into a :class:`~argparse.Namespace` and passed into your
|
||||||
LightningModule for use.
|
:class:`LightningModule` for use.
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
LightningModule with loaded weights and hyperparameters (if available).
|
:class:`LightningModule` with loaded weights and hyperparameters (if available).
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
@ -1558,7 +1556,7 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
|
||||||
|
|
||||||
def freeze(self) -> None:
|
def freeze(self) -> None:
|
||||||
r"""
|
r"""
|
||||||
Freeze all params for inference
|
Freeze all params for inference.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
@ -1573,7 +1571,8 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
|
||||||
self.eval()
|
self.eval()
|
||||||
|
|
||||||
def unfreeze(self) -> None:
|
def unfreeze(self) -> None:
|
||||||
"""Unfreeze all params for training.
|
"""
|
||||||
|
Unfreeze all parameters for training.
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
|
@ -1588,8 +1587,8 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
|
||||||
|
|
||||||
def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
|
def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
|
||||||
r"""
|
r"""
|
||||||
Called by lightning to restore your model.
|
Called by Lightning to restore your model.
|
||||||
If you saved something with **on_save_checkpoint** this is your chance to restore this.
|
If you saved something with :meth:`on_save_checkpoint` this is your chance to restore this.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
checkpoint: Loaded checkpoint
|
checkpoint: Loaded checkpoint
|
||||||
|
@ -1602,15 +1601,15 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
|
||||||
# 99% of the time you don't need to implement this method
|
# 99% of the time you don't need to implement this method
|
||||||
self.something_cool_i_want_to_save = checkpoint['something_cool_i_want_to_save']
|
self.something_cool_i_want_to_save = checkpoint['something_cool_i_want_to_save']
|
||||||
|
|
||||||
.. note:: Lighting auto-restores global step, epoch, and train state including amp scaling.
|
Note:
|
||||||
No need for you to restore anything regarding training.
|
Lightning auto-restores global step, epoch, and train state including amp scaling.
|
||||||
|
There is no need for you to restore anything regarding training.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
|
def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
|
||||||
r"""
|
r"""
|
||||||
|
Called by Lightning when saving a checkpoint to give you a chance to store anything
|
||||||
Called by lightning when saving a checkpoint to give you a chance to store anything
|
else you might want to save.
|
||||||
else you might want to save
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
checkpoint: Checkpoint to be saved
|
checkpoint: Checkpoint to be saved
|
||||||
|
@ -1618,13 +1617,15 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
|
||||||
Example:
|
Example:
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
|
|
||||||
def on_save_checkpoint(self, checkpoint):
|
def on_save_checkpoint(self, checkpoint):
|
||||||
# 99% of use cases you don't need to implement this method
|
# 99% of use cases you don't need to implement this method
|
||||||
checkpoint['something_cool_i_want_to_save'] = my_cool_pickable_object
|
checkpoint['something_cool_i_want_to_save'] = my_cool_pickable_object
|
||||||
|
|
||||||
.. note:: Lighting saves all aspects of training (epoch, global step, etc...)
|
Note:
|
||||||
including amp scaling. No need
|
Lightning saves all aspects of training (epoch, global step, etc...)
|
||||||
for you to store anything about training.
|
including amp scaling.
|
||||||
|
There is no need for you to store anything about training.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|