2019-10-31 10:45:28 +00:00
|
|
|
import collections
|
2020-02-25 15:36:44 +00:00
|
|
|
import inspect
|
2020-02-01 20:47:58 +00:00
|
|
|
import logging as log
|
2020-01-20 19:50:31 +00:00
|
|
|
import os
|
|
|
|
import warnings
|
2020-01-14 03:20:38 +00:00
|
|
|
from abc import ABC, abstractmethod
|
2019-10-23 08:48:24 +00:00
|
|
|
from argparse import Namespace
|
2020-03-04 14:33:39 +00:00
|
|
|
from typing import Any, Callable, Dict, Optional, Union
|
2019-10-22 08:32:40 +00:00
|
|
|
|
2019-03-31 01:45:16 +00:00
|
|
|
import torch
|
2019-11-05 15:01:52 +00:00
|
|
|
import torch.distributed as dist
|
2020-03-02 03:15:55 +00:00
|
|
|
from torch.optim import Adam
|
2020-01-29 19:52:23 +00:00
|
|
|
|
2019-11-27 03:39:18 +00:00
|
|
|
from pytorch_lightning.core.decorators import data_loader
|
|
|
|
from pytorch_lightning.core.grads import GradInformation
|
|
|
|
from pytorch_lightning.core.hooks import ModelHooks
|
2020-02-25 18:06:24 +00:00
|
|
|
from pytorch_lightning.core.saving import ModelIO, load_hparams_from_tags_csv
|
2020-01-21 20:18:32 +00:00
|
|
|
from pytorch_lightning.core.memory import ModelSummary
|
2019-11-27 03:39:18 +00:00
|
|
|
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
|
2020-02-25 15:36:44 +00:00
|
|
|
from pytorch_lightning.utilities.debugging import MisconfigurationException
|
2019-10-21 06:16:55 +00:00
|
|
|
|
2020-02-25 03:23:25 +00:00
|
|
|
try:
|
|
|
|
import torch_xla.core.xla_model as xm
|
|
|
|
XLA_AVAILABLE = True
|
|
|
|
|
|
|
|
except ImportError:
|
|
|
|
XLA_AVAILABLE = False
|
|
|
|
|
2019-03-31 01:45:16 +00:00
|
|
|
|
2020-01-14 03:20:38 +00:00
|
|
|
class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
|
2020-02-09 22:39:10 +00:00
|
|
|
|
2019-07-25 16:08:00 +00:00
|
|
|
def __init__(self, *args, **kwargs):
    """Set up the default bookkeeping state shared by every LightningModule.

    All positional and keyword arguments are handed, untouched, to the next
    ``__init__`` in the method resolution order. Every attribute below starts
    out at a neutral default (``None``, ``0``, ``False`` or an empty dict).
    """
    super(LightningModule, self).__init__(*args, **kwargs)

    # ---- tensor type and experiment bookkeeping -------------------------
    #: Current dtype
    self.dtype = torch.FloatTensor
    self.exp_save_path = None

    # ---- progress counters ----------------------------------------------
    #: The current epoch
    self.current_epoch = 0
    #: Total training batches seen across all epochs
    self.global_step = 0
    self.loaded_optimizer_states_dict = {}

    # ---- collaborators (None until assigned elsewhere) ------------------
    #: Pointer to the trainer object
    self.trainer = None
    #: Pointer to the logger object
    self.logger = None
    self.example_input_array = None

    # ---- hardware / distributed-mode flags ------------------------------
    #: True if your model is currently running on GPUs.
    #: Useful to set flags around the LightningModule for different CPU vs GPU behavior.
    self.on_gpu = False
    #: True if using dp
    self.use_dp = False
    #: True if using ddp
    self.use_ddp = False
    #: True if using ddp2
    self.use_ddp2 = False
    #: True if using amp
    self.use_amp = False

    # ---- hyperparameters ------------------------------------------------
    self.hparams = None
|
2020-03-04 14:33:39 +00:00
|
|
|
|
2020-02-25 03:30:53 +00:00
|
|
|
def print(self, *args, **kwargs):
    r"""
    Log a message from process 0 only. Use this in any distributed mode so that
    the message is emitted once rather than once per process.

    The positional arguments are converted with ``str`` and joined with ``sep``
    (a single space by default), mirroring the builtin :func:`print`. The
    resulting message is emitted with ``logging.info``. The builtin-print
    options ``end``, ``file`` and ``flush`` are accepted and ignored; any other
    keyword argument (e.g. ``exc_info``) is forwarded to ``logging.info``.

    If the module is not attached to a trainer (``self.trainer is None``), the
    message is always logged.

    Args:
        *args: The objects to print.
        **kwargs: ``sep`` plus any keyword arguments accepted by ``logging.info``.

    Examples:
        .. code-block:: python

            # example if we were using this model as a feature extractor
            def forward(self, x):
                self.print(x, 'in loader')

    """
    trainer = self.trainer
    # Without a trainer there is no process rank; behave like rank 0 instead of
    # failing with AttributeError on ``None.proc_rank``.
    if trainer is not None and trainer.proc_rank != 0:
        return

    sep = kwargs.pop('sep', None)
    if sep is None:
        # builtin print() treats sep=None as a single space
        sep = ' '
    # These builtin-print options have no logging counterpart.
    for print_only_key in ('end', 'file', 'flush'):
        kwargs.pop(print_only_key, None)

    # Build the message up front: log.info(*args) would treat args[1:] as
    # %-format arguments for str(args[0]) and break calls like print(x, 'text').
    message = sep.join(str(arg) for arg in args)
    log.info(message, **kwargs)
|
2020-02-25 03:30:53 +00:00
|
|
|
|
2020-01-14 03:20:38 +00:00
|
|
|
@abstractmethod
def forward(self, *args, **kwargs):
    r"""
    Define the computation used at prediction time.

    This plays the same role as ``torch.nn.Module.forward()``, but in Lightning it
    should describe only the operations you want for prediction (i.e. when serving
    the model or using it as a feature extractor).

    Normally you'd call ``self.forward()`` from your ``training_step()`` method.
    That way a complex training system is built on top of the very outputs you
    would want in a prediction setting.

    Args:
        *args: Whatever positional inputs you decide ``forward`` takes (e.g. a tensor ``x``).
        **kwargs: Whatever keyword inputs you decide ``forward`` takes.

    Return:
        Predicted output

    Examples:
        .. code-block:: python

            # example if we were using this model as a feature extractor
            def forward(self, x):
                feature_maps = self.convnet(x)
                return feature_maps

            def training_step(self, batch, batch_idx):
                x, y = batch
                feature_maps = self.forward(x)
                logits = self.classifier(feature_maps)

                # ...
                return loss

            # splitting it this way allows the model to be used as a feature extractor
            model = MyModelAbove()

            inputs = server.get_request()
            results = model(inputs)
            server.write_results(results)

            # -------------
            # This is in stark contrast to torch.nn.Module where normally you would have this:
            def forward(self, batch):
                x, y = batch
                feature_maps = self.convnet(x)
                logits = self.classifier(feature_maps)
                return logits

    """
|
|
|
|
|
2019-08-13 15:37:37 +00:00
|
|
|
def training_step(self, *args, **kwargs):
    r"""
    Compute the loss (and optional metrics) for a single training batch.

    In this step you'd normally do the forward pass and calculate the loss for a batch.
    You can also do fancier things like multiple forward passes or something model specific.

    Args:
        batch (torch.Tensor | (Tensor, Tensor) | [Tensor, Tensor]): The output of your
            dataloader. A tensor, tuple or list
        batch_idx (int): Integer displaying index of this batch
        optimizer_idx (int): If using multiple optimizers, this argument will also be present.
        hiddens (Tensor): Passed in if truncated_bptt_steps > 0.

    Return:
        dict with loss key and optional log, progress keys.
        If implementing training_step, return whatever you need in that step:

        - loss -> tensor scalar [REQUIRED]
        - progress_bar -> Dict for progress bar display. Must have only tensors
        - log -> Dict of metrics to add to logger. Must have only tensors (no images, etc)

    Examples:
        .. code-block:: python

            def training_step(self, batch, batch_idx):
                x, y = batch

                # implement your own
                out = self.forward(x)
                loss = self.loss(out, y)

                logger_logs = {'training_loss': loss}  # optional (MUST ALL BE TENSORS)

                # if using TestTubeLogger or TensorBoardLogger you can nest scalars
                logger_logs = {'losses': logger_logs}  # optional (MUST ALL BE TENSORS)

                output = {
                    'loss': loss,  # required
                    'progress_bar': {'training_loss': loss},  # optional (MUST ALL BE TENSORS)
                    'log': logger_logs
                }

                # return a dict
                return output

        If you define multiple optimizers, this step will be called with an additional
        `optimizer_idx` param.

        .. code-block:: python

            # Multiple optimizers (i.e. GANs)
            def training_step(self, batch, batch_idx, optimizer_idx):
                if optimizer_idx == 0:
                    # do training_step with encoder
                    ...
                if optimizer_idx == 1:
                    # do training_step with decoder
                    ...

        If you add truncated back propagation through time you will also get an additional
        argument with the hidden states of the previous step.

        .. code-block:: python

            # Truncated back-propagation through time
            def training_step(self, batch, batch_idx, hiddens):
                # hiddens are the hidden states from the previous truncated backprop step
                ...
                out, hiddens = self.lstm(data, hiddens)
                ...

                return {
                    "loss": ...,
                    "hiddens": hiddens  # remember to detach() this
                }

    You can also return a -1 instead of a dict to stop the current loop. This is useful
    if you want to break out of the current training epoch early.
    """
|
|
|
|
|
2019-11-05 15:01:52 +00:00
|
|
|
def training_end(self, *args, **kwargs):
    """
    Deprecated hook kept for backward compatibility.

    Warnings:
        Deprecated in v0.7.0. Use ``training_step_end`` instead.
    """
|
|
|
|
|
|
|
|
def training_step_end(self, *args, **kwargs):
    """
    Use this when training with dp or ddp2 because training_step will operate
    on only part of the batch. However, this is still optional
    and only needed for things like softmax or NCE loss.

    .. note:: If you later switch to ddp or some other mode, this will still be called
        so that you don't have to change your code

    .. code-block:: python

        # pseudocode
        sub_batches = split_batches_for_dp(batch)
        batch_parts_outputs = [training_step(sub_batch) for sub_batch in sub_batches]
        training_step_end(batch_parts_outputs)

    Args:
        batch_parts_outputs: What you return in `training_step` for each batch part.

    Return:
        dictionary with loss key and optional log, progress keys:

        - loss -> tensor scalar [REQUIRED]
        - progress_bar -> Dict for progress bar display. Must have only tensors
        - log -> Dict of metrics to add to logger. Must have only tensors (no images, etc)

    When a computation needs the full batch (e.g. softmax or NCE loss), define
    training_step_end to perform it.

    Examples:
        .. code-block:: python

            # WITHOUT training_step_end
            # if used in DP or DDP2, this batch is 1/num_gpus large
            def training_step(self, batch, batch_idx):
                # batch is 1/num_gpus big
                x, y = batch

                out = self.forward(x)
                loss = self.softmax(out)
                loss = nce_loss(loss)
                return {'loss': loss}

            # --------------
            # with training_step_end to do softmax over the full batch
            def training_step(self, batch, batch_idx):
                # batch is 1/num_gpus big
                x, y = batch

                out = self.forward(x)
                return {'out': out}

            def training_step_end(self, outputs):
                # this out is now the full size of the batch
                out = outputs['out']

                # this softmax now uses the full batch size
                loss = self.softmax(out)
                loss = nce_loss(loss)
                return {'loss': loss}

    .. seealso:: see the `multi-gpu guide for more details <multi_gpu.rst#caveats>`_.
    """
|
|
|
|
|
2019-08-13 15:37:37 +00:00
|
|
|
def validation_step(self, *args, **kwargs):
    r"""
    Operate on a single batch of data from the validation set.
    In this step you might generate examples or calculate anything of interest like accuracy.

    .. code-block:: python

        # the pseudocode for these calls
        val_outs = []
        for val_batch in val_data:
            out = validation_step(val_batch)
            val_outs.append(out)
        validation_epoch_end(val_outs)

    Args:
        batch (torch.Tensor | (Tensor, Tensor) | [Tensor, Tensor]): The output of your
            dataloader. A tensor, tuple or list
        batch_idx (int): The index of this batch
        dataloader_idx (int): The index of the dataloader that produced this batch
            (only if multiple val datasets used)

    Return:
        Dict or OrderedDict - passed to validation_epoch_end

    .. code-block:: python

        # if you have one val dataloader:
        def validation_step(self, batch, batch_idx)

        # if you have multiple val dataloaders:
        def validation_step(self, batch, batch_idx, dataloader_idx)

    Examples:
        .. code-block:: python

            # CASE 1: A single validation dataset
            def validation_step(self, batch, batch_idx):
                x, y = batch

                # implement your own
                out = self.forward(x)
                loss = self.loss(out, y)

                # log 6 example images
                # or generated text... or whatever
                sample_imgs = x[:6]
                grid = torchvision.utils.make_grid(sample_imgs)
                self.logger.experiment.add_image('example_images', grid, 0)

                # calculate acc
                labels_hat = torch.argmax(out, dim=1)
                val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

                # all optional...
                # return whatever you need for the collation function validation_epoch_end
                output = OrderedDict({
                    'val_loss': loss,
                    'val_acc': torch.tensor(val_acc),  # everything must be a tensor
                })

                # return an optional dict
                return output

        If you pass in multiple val datasets, validation_step will have an additional argument.

        .. code-block:: python

            # CASE 2: multiple validation datasets
            def validation_step(self, batch, batch_idx, dataloader_idx):
                # dataloader_idx tells you which dataset this is.
                ...

    .. note:: If you don't need to validate you don't need to implement this method.

    .. note:: When the validation_step is called, the model has been put in eval mode
        and PyTorch gradients have been disabled. At the end of validation,
        the model goes back to training mode and gradients are enabled.
    """
|
|
|
|
|
2020-03-05 17:32:45 +00:00
|
|
|
def validation_step_end(self, *args, **kwargs):
    """
    Use this when validating with dp or ddp2 because validation_step will operate
    on only part of the batch. However, this is still optional
    and only needed for things like softmax or NCE loss.

    .. note:: If you later switch to ddp or some other mode, this will still be called
        so that you don't have to change your code

    .. code-block:: python

        # pseudocode
        sub_batches = split_batches_for_dp(batch)
        batch_parts_outputs = [validation_step(sub_batch) for sub_batch in sub_batches]
        validation_step_end(batch_parts_outputs)

    Args:
        batch_parts_outputs: What you return in `validation_step` for each batch part.

    Return:
        dictionary with loss key and optional log, progress keys:

        - loss -> tensor scalar [REQUIRED]
        - progress_bar -> Dict for progress bar display. Must have only tensors
        - log -> Dict of metrics to add to logger. Must have only tensors (no images, etc)

    When a computation needs the full batch (e.g. softmax or NCE loss), define
    validation_step_end to perform it.

    Examples:
        .. code-block:: python

            # WITHOUT validation_step_end
            # if used in DP or DDP2, this batch is 1/num_gpus large
            def validation_step(self, batch, batch_idx):
                # batch is 1/num_gpus big
                x, y = batch

                out = self.forward(x)
                loss = self.softmax(out)
                loss = nce_loss(loss)
                return {'loss': loss}

            # --------------
            # with validation_step_end to do softmax over the full batch
            def validation_step(self, batch, batch_idx):
                # batch is 1/num_gpus big
                x, y = batch

                out = self.forward(x)
                return {'out': out}

            def validation_step_end(self, outputs):
                # this out is now the full size of the batch
                out = outputs['out']

                # this softmax now uses the full batch size
                loss = self.softmax(out)
                loss = nce_loss(loss)
                return {'loss': loss}

    .. seealso:: see the `multi-gpu guide for more details <multi_gpu.rst#caveats>`_.
    """
|
2019-11-28 17:48:55 +00:00
|
|
|
|
2020-03-05 17:32:45 +00:00
|
|
|
def validation_end(self, outputs):
    """
    Deprecated hook kept for backward compatibility.

    Warnings:
        Deprecated in v0.7.0. Use ``validation_epoch_end`` instead. Will be removed 1.0.0
    """
|
2019-11-28 17:48:55 +00:00
|
|
|
|
2020-03-05 23:52:17 +00:00
|
|
|
def validation_epoch_end(self, outputs: list):
    """
    Called at the end of the validation epoch with the outputs of all validation steps.

    .. code-block:: python

        # the pseudocode for these calls
        val_outs = []
        for val_batch in val_data:
            out = validation_step(val_batch)
            val_outs.append(out)
        validation_epoch_end(val_outs)

    Args:
        outputs: List of outputs you defined in validation_step, or if there are multiple
            dataloaders, a list containing a list of outputs for each dataloader

    Return:
        Dict or OrderedDict (dict): Dict has the following optional keys:

        - progress_bar -> Dict for progress bar display. Must have only tensors
        - log -> Dict of metrics to add to logger. Must have only tensors (no images, etc)

    .. note:: If you didn't define a validation_step, this won't be called.

        - The outputs here are strictly for logging or progress bar.
        - If you don't need to display anything, don't return anything.
        - If you want to manually set current step, you can specify the 'step' key in the 'log' Dict

    Examples:
        With a single dataloader

        .. code-block:: python

            def validation_epoch_end(self, outputs):
                val_acc_mean = 0
                for output in outputs:
                    val_acc_mean += output['val_acc']

                val_acc_mean /= len(outputs)
                tqdm_dict = {'val_acc': val_acc_mean.item()}

                # show val_acc in the progress bar and also log it
                results = {
                    'progress_bar': tqdm_dict,
                    'log': {'val_acc': val_acc_mean.item()}
                }
                return results

        With multiple dataloaders, `outputs` will be a list of lists. The outer list contains
        one entry per dataloader, while the inner list contains the individual outputs of
        each validation step for that dataloader.

        .. code-block:: python

            def validation_epoch_end(self, outputs):
                val_acc_mean = 0
                i = 0
                for dataloader_outputs in outputs:
                    for output in dataloader_outputs:
                        val_acc_mean += output['val_acc']
                        i += 1

                val_acc_mean /= i
                tqdm_dict = {'val_acc': val_acc_mean.item()}

                # show val_acc in the progress bar and log it at step=current_epoch
                results = {
                    'progress_bar': tqdm_dict,
                    'log': {'val_acc': val_acc_mean.item(), 'step': self.current_epoch}
                }
                return results
    """
|
|
|
|
|
|
|
|
def test_step(self, *args, **kwargs):
    r"""
    Operate on a single batch of data from the test set.
    In this step you'd normally generate examples or calculate anything of interest
    such as accuracy.

    .. code-block:: python

        # the pseudocode for these calls
        test_outs = []
        for test_batch in test_data:
            out = test_step(test_batch)
            test_outs.append(out)
        test_epoch_end(test_outs)

    Args:
        batch (torch.Tensor | (Tensor, Tensor) | [Tensor, Tensor]): The output of your
            dataloader. A tensor, tuple or list
        batch_idx (int): The index of this batch
        dataloader_idx (int): The index of the dataloader that produced this batch
            (only if multiple test datasets used)

    Return:
        Dict or OrderedDict - passed to the test_epoch_end

    .. code-block:: python

        # if you have one test dataloader:
        def test_step(self, batch, batch_idx)

        # if you have multiple test dataloaders:
        def test_step(self, batch, batch_idx, dataloader_idx)

    Examples:
        .. code-block:: python

            # CASE 1: A single test dataset
            def test_step(self, batch, batch_idx):
                x, y = batch

                # implement your own
                out = self.forward(x)
                loss = self.loss(out, y)

                # log 6 example images
                # or generated text... or whatever
                sample_imgs = x[:6]
                grid = torchvision.utils.make_grid(sample_imgs)
                self.logger.experiment.add_image('example_images', grid, 0)

                # calculate acc
                labels_hat = torch.argmax(out, dim=1)
                test_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

                # all optional...
                # return whatever you need for the collation function test_epoch_end
                output = OrderedDict({
                    'test_loss': loss,
                    'test_acc': torch.tensor(test_acc),  # everything must be a tensor
                })

                # return an optional dict
                return output

        If you pass in multiple test datasets, test_step will have an additional
        argument.

        .. code-block:: python

            # CASE 2: multiple test datasets
            def test_step(self, batch, batch_idx, dataloader_idx):
                # dataloader_idx tells you which dataset this is.
                ...

    .. note:: If you don't need to test you don't need to implement this method.

    .. note:: When the test_step is called, the model has been put in eval mode and
        PyTorch gradients have been disabled. At the end of the test loop, the model goes
        back to training mode and gradients are enabled.
    """
|
|
|
|
|
|
|
|
def test_step_end(self, *args, **kwargs):
    """
    Use this when testing with dp or ddp2 because test_step will operate
    on only part of the batch. However, this is still optional
    and only needed for things like softmax or NCE loss.

    .. note:: If you later switch to ddp or some other mode, this will still be called
        so that you don't have to change your code

    .. code-block:: python

        # pseudocode
        sub_batches = split_batches_for_dp(batch)
        batch_parts_outputs = [test_step(sub_batch) for sub_batch in sub_batches]
        test_step_end(batch_parts_outputs)

    Args:
        batch_parts_outputs: What you return in `test_step` for each batch part.

    Return:
        dictionary with loss key and optional log, progress keys:

        - loss -> tensor scalar [REQUIRED]
        - progress_bar -> Dict for progress bar display. Must have only tensors
        - log -> Dict of metrics to add to logger. Must have only tensors (no images, etc)

    When a computation needs the full batch (e.g. softmax or NCE loss), define
    test_step_end to perform it.

    Examples:
        .. code-block:: python

            # WITHOUT test_step_end
            # if used in DP or DDP2, this batch is 1/num_gpus large
            def test_step(self, batch, batch_idx):
                # batch is 1/num_gpus big
                x, y = batch

                out = self.forward(x)
                loss = self.softmax(out)
                loss = nce_loss(loss)
                return {'loss': loss}

            # --------------
            # with test_step_end to do softmax over the full batch
            def test_step(self, batch, batch_idx):
                # batch is 1/num_gpus big
                x, y = batch

                out = self.forward(x)
                return {'out': out}

            def test_step_end(self, outputs):
                # this out is now the full size of the batch
                out = outputs['out']

                # this softmax now uses the full batch size
                loss = self.softmax(out)
                loss = nce_loss(loss)
                return {'loss': loss}

    .. seealso:: see the `multi-gpu guide for more details <multi_gpu.rst#caveats>`_.
    """
|
|
|
|
|
2019-08-30 22:56:09 +00:00
|
|
|
def test_end(self, outputs):
    """
    Deprecated hook kept for backward compatibility.

    Warnings:
        Deprecated in v0.7.0. Use ``test_epoch_end`` instead. Will be removed 1.0.0
    """
|
|
|
|
|
|
|
|
def test_epoch_end(self, outputs):
    """
    Called at the end of the test epoch with the outputs of all test steps.

    .. code-block:: python

        # the pseudocode for these calls
        test_outs = []
        for test_batch in test_data:
            out = test_step(test_batch)
            test_outs.append(out)
        test_epoch_end(test_outs)

    Args:
        outputs (list): List of outputs you defined in test_step, or if there are multiple
            dataloaders, a list containing a list of outputs for each dataloader

    Return:
        Dict or OrderedDict (dict): Dict has the following optional keys:

        - progress_bar -> Dict for progress bar display. Must have only tensors
        - log -> Dict of metrics to add to logger. Must have only tensors (no images, etc)

    .. note:: If you didn't define a test_step, this won't be called.

        - The outputs here are strictly for logging or progress bar.
        - If you don't need to display anything, don't return anything.
        - If you want to manually set current step, specify it with the 'step' key in the 'log' Dict

    Examples:
        With a single dataloader

        .. code-block:: python

            def test_epoch_end(self, outputs):
                test_acc_mean = 0
                for output in outputs:
                    test_acc_mean += output['test_acc']

                test_acc_mean /= len(outputs)
                tqdm_dict = {'test_acc': test_acc_mean.item()}

                # show test_acc in the progress bar and also log it
                results = {
                    'progress_bar': tqdm_dict,
                    'log': {'test_acc': test_acc_mean.item()}
                }
                return results

        With multiple dataloaders, `outputs` will be a list of lists. The outer list contains
        one entry per dataloader, while the inner list contains the individual outputs of
        each test step for that dataloader.

        .. code-block:: python

            def test_epoch_end(self, outputs):
                test_acc_mean = 0
                i = 0
                for dataloader_outputs in outputs:
                    for output in dataloader_outputs:
                        test_acc_mean += output['test_acc']
                        i += 1

                test_acc_mean /= i
                tqdm_dict = {'test_acc': test_acc_mean.item()}

                # show test_acc in the progress bar and log it at step=current_epoch
                results = {
                    'progress_bar': tqdm_dict,
                    'log': {'test_acc': test_acc_mean.item(), 'step': self.current_epoch}
                }
                return results
    """
|
|
|
|
|
2019-11-05 15:01:52 +00:00
|
|
|
def configure_ddp(self, model, device_ids):
    r"""Wrap ``model`` for distributed data parallel (DDP) training.

    Override to initialize DDP your own way or with your own wrapper.
    Any replacement must keep this routing:

    1. A validation batch goes to ``model.validation_step``.
    2. A training batch goes to ``model.training_step``.
    3. A test batch goes to ``model.test_step``.

    Args:
        model (:class:`.LightningModule`): the LightningModule currently being optimized
        device_ids (list): the list of GPU ids

    Return:
        The DDP-wrapped model.

    Examples:
        .. code-block:: python

            # default implementation used by the Trainer
            def configure_ddp(self, model, device_ids):
                # Lightning DDP simply routes to test_step, val_step, etc...
                model = LightningDistributedDataParallel(
                    model,
                    device_ids=device_ids,
                    find_unused_parameters=True
                )
                return model
    """
    # find_unused_parameters=True tolerates parameters that receive no gradient in a step
    ddp_options = dict(device_ids=device_ids, find_unused_parameters=True)
    return LightningDistributedDataParallel(model, **ddp_options)
|
|
|
|
|
|
|
|
def init_ddp_connection(self, proc_rank, world_size):
    r"""Set up the distributed environment variables and the NCCL process group.

    Override to define your own way of setting up a distributed environment.

    Lightning's implementation uses env:// init by default and sets the first
    node as root:

    - ``MASTER_PORT`` defaults to the last 4 digits of ``SLURM_JOB_ID`` plus
      15000, or 12910 when that is unavailable. An existing ``MASTER_PORT`` is
      left untouched.
    - ``MASTER_ADDR`` is the first entry of ``SLURM_NODELIST``, or
      ``'127.0.0.2'``, after ``self.trainer.resolve_root_node_address``.

    Args:
        proc_rank (int): Rank of this process, passed as ``rank`` to
            ``torch.distributed.init_process_group``.
        world_size (int): Number of GPUs being used across all nodes
            (num_nodes * num_gpus_per_node).

    Examples:
        .. code-block:: python

            def init_ddp_connection(self, proc_rank, world_size):
                # use slurm job id for the port number
                # guarantees unique ports across jobs from same grid search
                try:
                    # use the last 4 numbers in the job id as the id,
                    # shifted so all ports are in the 10k+ range
                    default_port = int(os.environ['SLURM_JOB_ID'][-4:]) + 15000
                except (KeyError, ValueError):
                    default_port = 12910

                # if user gave a port number, use that one instead
                os.environ.setdefault('MASTER_PORT', str(default_port))

                # figure out the root node addr
                try:
                    root_node = os.environ['SLURM_NODELIST'].split(' ')[0]
                except KeyError:
                    root_node = '127.0.0.2'

                root_node = self.trainer.resolve_root_node_address(root_node)
                os.environ['MASTER_ADDR'] = root_node
                dist.init_process_group('nccl', rank=proc_rank, world_size=world_size)
    """
    # Derive the port from the SLURM job id so that concurrent jobs from the same
    # grid search get distinct ports: last 4 digits of the id, shifted to 15000+.
    try:
        default_port = int(os.environ['SLURM_JOB_ID'][-4:]) + 15000
    except (KeyError, ValueError):
        # KeyError: not running under SLURM; ValueError: id suffix is not numeric
        default_port = 12910

    # a port number given by the user always wins over the derived default
    os.environ.setdefault('MASTER_PORT', str(default_port))

    # the first entry of the SLURM node list acts as the root node
    try:
        root_node = os.environ['SLURM_NODELIST'].split(' ')[0]
    except KeyError:
        root_node = '127.0.0.2'

    root_node = self.trainer.resolve_root_node_address(root_node)
    os.environ['MASTER_ADDR'] = root_node
    dist.init_process_group('nccl', rank=proc_rank, world_size=world_size)
|
|
|
|
|
|
|
|
def configure_apex(self, amp, model, optimizers, amp_level):
    r"""Initialize NVIDIA Apex AMP for the model and its optimizers.

    Override to set up AMP your own way. Must return a model and a list of optimizers.

    Args:
        amp (object): pointer to the amp library object
        model (:class:`.LightningModule`): pointer to the current LightningModule
        optimizers (list): the optimizers returned by ``configure_optimizers()``
        amp_level (str): AMP mode chosen ('O1', 'O2', etc...)

    Return:
        Apex-wrapped model and optimizers.

    Examples:
        .. code-block:: python

            # Default implementation used by Trainer.
            def configure_apex(self, amp, model, optimizers, amp_level):
                model, optimizers = amp.initialize(
                    model, optimizers, opt_level=amp_level,
                )

                return model, optimizers
    """
    wrapped_model, wrapped_optimizers = amp.initialize(model, optimizers, opt_level=amp_level)
    return wrapped_model, wrapped_optimizers
|
|
|
|
|
2019-03-31 01:45:16 +00:00
|
|
|
def configure_optimizers(self):
    r"""Choose the optimizers and learning-rate schedulers used in optimization.

    Normally you need only one, but GANs and similar setups may need several.

    If you don't define this method, Lightning automatically uses ``Adam(lr=1e-3)``.

    Return: any of these 3 options:
        - Single optimizer
        - List or Tuple - List of optimizers
        - Two lists - The first list has multiple optimizers, the second a list of LR schedulers

    Examples:
        .. code-block:: python

            # most cases (default if not defined)
            def configure_optimizers(self):
                opt = Adam(self.parameters(), lr=1e-3)
                return opt

            # multiple optimizer case (eg: GAN)
            def configure_optimizers(self):
                generator_opt = Adam(self.model_gen.parameters(), lr=0.01)
                discriminator_opt = Adam(self.model_disc.parameters(), lr=0.02)
                return generator_opt, discriminator_opt

            # example with learning rate schedulers
            def configure_optimizers(self):
                generator_opt = Adam(self.model_gen.parameters(), lr=0.01)
                discriminator_opt = Adam(self.model_disc.parameters(), lr=0.02)
                discriminator_sched = CosineAnnealingLR(discriminator_opt, T_max=10)
                return [generator_opt, discriminator_opt], [discriminator_sched]

            # example with step-based learning rate schedulers
            def configure_optimizers(self):
                gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
                dis_opt = Adam(self.model_disc.parameters(), lr=0.02)
                gen_sched = {'scheduler': ExponentialLR(gen_opt, 0.99),
                             'interval': 'step'}  # called after each training step
                dis_sched = CosineAnnealingLR(dis_opt, T_max=10)  # called every epoch
                return [gen_opt, dis_opt], [gen_sched, dis_sched]

    .. note:: Lightning calls ``.backward()`` and ``.step()`` on each optimizer
        and learning rate scheduler as needed.

    .. note:: If you use 16-bit precision (``use_amp=True``), Lightning automatically
        handles the optimizers for you.

    .. note:: If you use multiple optimizers, ``training_step`` gets an additional
        ``optimizer_idx`` parameter.

    .. note:: If you use LBFGS, Lightning handles the closure function automatically for you.

    .. note:: If you use multiple optimizers, gradients are calculated only
        for the parameters of the current optimizer at each training step.

    .. note:: To control how often those optimizers step, or to override the
        default ``.step()`` schedule, override the ``optimizer_step`` hook.

    .. note:: To call a learning rate scheduler only every ``x`` steps or epochs,
        pass it with a 'frequency' key: ``dict(scheduler=lr_scheduler,
        interval='step' or 'epoch', frequency=x)``.
    """
    trainable = self.parameters()
    return Adam(trainable, lr=1e-3)
|
2019-03-31 01:45:16 +00:00
|
|
|
|
2019-12-07 13:50:21 +00:00
|
|
|
def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, second_order_closure=None):
    r"""Step one optimizer and clear its gradients.

    Override this method to adjust how the Trainer calls each optimizer.
    By default, Lightning calls ``.step()`` and ``.zero_grad()`` once per optimizer,
    as shown in the example. On TPU (when ``torch_xla`` is available) the step goes
    through ``xm.optimizer_step``. ``torch.optim.LBFGS`` receives ``second_order_closure``.

    Args:
        epoch (int): Current epoch
        batch_idx (int): Index of current batch
        optimizer (torch.optim.Optimizer): A PyTorch optimizer
        optimizer_idx (int): If you used multiple optimizers, this indexes into that list
        second_order_closure (callable): Closure for second-order methods such as LBFGS

    Examples:
        .. code-block:: python

            # DEFAULT
            def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
                               second_order_closure=None):
                optimizer.step()
                optimizer.zero_grad()

            # Alternating schedule for optimizer steps (ie: GANs)
            def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
                               second_order_closure=None):
                # update generator opt every 2 steps
                if optimizer_idx == 0:
                    if batch_idx % 2 == 0:
                        optimizer.step()
                        optimizer.zero_grad()

                # update discriminator opt every 4 steps
                if optimizer_idx == 1:
                    if batch_idx % 4 == 0:
                        optimizer.step()
                        optimizer.zero_grad()

                # ...
                # add as many optimizers as you want

        Here's another example showing how to use this for more advanced things such as
        learning-rate warm-up:

        .. code-block:: python

            # learning rate warm-up
            def optimizer_step(self, epoch, batch_idx, optimizer,
                               optimizer_idx, second_order_closure=None):
                # warm up lr
                if self.trainer.global_step < 500:
                    lr_scale = min(1., float(self.trainer.global_step + 1) / 500.)
                    for pg in optimizer.param_groups:
                        pg['lr'] = lr_scale * self.hparams.learning_rate

                # update params
                optimizer.step()
                optimizer.zero_grad()
    """
    # the use_tpu check comes first so XLA_AVAILABLE is only consulted when on TPU
    run_on_xla = self.trainer.use_tpu and XLA_AVAILABLE
    if run_on_xla:
        xm.optimizer_step(optimizer)
    elif isinstance(optimizer, torch.optim.LBFGS):
        # LBFGS re-evaluates the loss several times per step, so it needs the closure
        optimizer.step(second_order_closure)
    else:
        optimizer.step()

    # gradients are cleared after every step
    optimizer.zero_grad()
|
|
|
|
|
2019-10-31 10:45:28 +00:00
|
|
|
def tbptt_split_batch(self, batch, split_size):
    r"""Split a batch along its time dimension for truncated backpropagation through time.

    When using truncated backpropagation through time, each batch must be split along
    the time dimension. Lightning handles this by default; override this function for
    custom behavior.

    Args:
        batch: Current batch; an iterable whose elements are tensors, sequences of
            per-sample sequences, or other objects.
        split_size (int): How many time steps each split contains.

    Return:
        list of batch splits. Each split is passed to ``training_step`` to enable
        truncated back propagation through time. The default implementation slices
        root-level Tensors at dim=1 (i.e. the time dim) and slices every per-sample
        entry of root-level Sequences. Elements that are neither are passed through
        unchanged in every split. It assumes that each time dim has the same length.

    Raises:
        AssertionError: if no element has a time dimension, or the time dimensions
            of the elements differ.

    Examples:
        .. code-block:: python

            def tbptt_split_batch(self, batch, split_size):
                time_dims = [len(x[0]) for x in batch
                             if isinstance(x, (torch.Tensor, collections.abc.Sequence))]
                splits = []
                for t in range(0, time_dims[0], split_size):
                    batch_split = []
                    for x in batch:
                        if isinstance(x, torch.Tensor):
                            split_x = x[:, t:t + split_size]
                        elif isinstance(x, collections.abc.Sequence):
                            split_x = [sample[t:t + split_size] for sample in x]
                        else:
                            split_x = x

                        batch_split.append(split_x)

                    splits.append(batch_split)

                return splits

    .. note:: Called in the training loop after ``on_batch_start`` if
        ``truncated_bptt_steps > 0``. Each returned batch split is passed separately
        to ``training_step(...)``.
    """
    # ``collections.Sequence`` was removed in Python 3.10; the ABC lives in collections.abc
    from collections.abc import Sequence

    time_dims = [len(x[0]) for x in batch if isinstance(x, (torch.Tensor, Sequence))]
    assert len(time_dims) >= 1, "Unable to determine batch time dimension"
    assert all(x == time_dims[0] for x in time_dims), "Batch time dimension length is ambiguous"

    splits = []
    for t in range(0, time_dims[0], split_size):
        batch_split = []
        for x in batch:
            if isinstance(x, torch.Tensor):
                # dim 1 is treated as the time dimension
                split_x = x[:, t:t + split_size]
            elif isinstance(x, Sequence):
                # a sequence of per-sample sequences: slice each sample in time
                split_x = [sample[t:t + split_size] for sample in x]
            else:
                # nothing to split; without this branch a stale ``split_x`` from the
                # previous element was appended (or an UnboundLocalError was raised)
                split_x = x
            batch_split.append(split_x)
        splits.append(batch_split)

    return splits
|
|
|
|
|
2020-02-25 03:23:25 +00:00
|
|
|
def prepare_data(self):
    """Download and prepare data.

    In distributed settings (GPU, TPU) this is called only once.
    It runs before any dataloader is requested:

    .. code-block:: python

        model.prepare_data()
        model.train_dataloader()
        model.val_dataloader()
        model.test_dataloader()

    Return:
        Nothing; the default implementation returns ``None``.

    Examples:
        .. code-block:: python

            def prepare_data(self):
                download_imagenet()
                clean_imagenet()
                cache_imagenet()
    """
|
|
|
|
|
2019-10-21 06:16:55 +00:00
|
|
|
def train_dataloader(self):
    """Implement a PyTorch DataLoader for training.

    The dataloader is not re-requested every epoch unless you set
    ``Trainer(reload_dataloaders_every_epoch=True)``.

    It's recommended that all data downloads and preparation happen in ``prepare_data()``.

    Return:
        PyTorch DataLoader (the default implementation returns ``None``).

    .. note:: Lightning adds the correct sampler for distributed and arbitrary hardware.
        There is no need to set one yourself.

    The related hooks run in this order:

    - .fit()
    - ...
    - prepare_data()
    - train_dataloader

    Example:
        .. code-block:: python

            def train_dataloader(self):
                transform = transforms.Compose([transforms.ToTensor(),
                                                transforms.Normalize((0.5,), (1.0,))])
                dataset = MNIST(root='/path/to/mnist/', train=True, transform=transform,
                                download=True)
                loader = torch.utils.data.DataLoader(
                    dataset=dataset,
                    batch_size=self.hparams.batch_size,
                    shuffle=True
                )
                return loader
    """
|
2019-12-04 11:59:19 +00:00
|
|
|
|
|
|
|
@data_loader
def tng_dataloader(self):  # todo: remove in v1.0.0
    """Deprecated alias of ``train_dataloader``; implement a PyTorch DataLoader.

    Calls ``self.train_dataloader()``, emits a ``DeprecationWarning`` and
    returns the loader it got.

    Warnings:
        Deprecated in v0.5.0. Use ``train_dataloader`` instead. Will be removed in 1.0.0.
    """
    loader = self.train_dataloader()
    warnings.warn("`tng_dataloader` has been renamed to `train_dataloader` since v0.5.0."
                  " and this method will be removed in v1.0.0", DeprecationWarning)
    return loader
|
2019-10-21 06:16:55 +00:00
|
|
|
|
2019-03-31 01:45:16 +00:00
|
|
|
def test_dataloader(self):
|
2020-01-17 11:03:31 +00:00
|
|
|
r"""
|
2019-11-28 17:48:55 +00:00
|
|
|
|
2020-02-25 20:05:56 +00:00
|
|
|
Return a dataloader. It will not be called every epoch unless you set
|
|
|
|
```Trainer(reload_dataloaders_every_epoch=True)```.
|
|
|
|
|
|
|
|
It's recommended that all data downloads and preparation happen in prepare_data().
|
|
|
|
|
2020-03-05 23:52:17 +00:00
|
|
|
- .fit()
|
|
|
|
- ...
|
|
|
|
- prepare_data()
|
|
|
|
- train_dataloader
|
|
|
|
- val_dataloader
|
|
|
|
- test_dataloader
|
2020-02-25 20:05:56 +00:00
|
|
|
|
2020-03-05 23:52:17 +00:00
|
|
|
.. note:: Lightning adds the correct sampler for distributed and arbitrary hardware.
|
|
|
|
No need to set yourself.
|
2020-01-17 11:03:31 +00:00
|
|
|
|
|
|
|
Return:
|
|
|
|
PyTorch DataLoader
|
2019-11-28 17:48:55 +00:00
|
|
|
|
2020-03-05 23:52:17 +00:00
|
|
|
Example:
|
|
|
|
.. code-block:: python
|
2019-11-28 17:48:55 +00:00
|
|
|
|
2020-03-05 23:52:17 +00:00
|
|
|
def test_dataloader(self):
|
|
|
|
transform = transforms.Compose([transforms.ToTensor(),
|
|
|
|
transforms.Normalize((0.5,), (1.0,))])
|
|
|
|
dataset = MNIST(root='/path/to/mnist/', train=False, transform=transform,
|
|
|
|
download=True)
|
|
|
|
loader = torch.utils.data.DataLoader(
|
|
|
|
dataset=dataset,
|
|
|
|
batch_size=self.hparams.batch_size,
|
|
|
|
shuffle=True
|
|
|
|
)
|
2019-11-28 17:48:55 +00:00
|
|
|
|
2020-03-05 23:52:17 +00:00
|
|
|
return loader
|
2019-11-28 17:48:55 +00:00
|
|
|
|
2020-03-05 23:52:17 +00:00
|
|
|
.. note:: If you don't need a test dataset and a test_step, you don't need to implement
|
|
|
|
this method.
|
2020-01-17 11:03:31 +00:00
|
|
|
|
2020-03-05 23:52:17 +00:00
|
|
|
.. note:: If you want to change the data during every epoch DON'T use the data_loader
|
|
|
|
decorator.
|
2020-01-17 11:03:31 +00:00
|
|
|
|
2019-03-31 01:45:16 +00:00
|
|
|
"""
|
2019-08-11 14:01:57 +00:00
|
|
|
return None
|
2019-03-31 01:45:16 +00:00
|
|
|
|
|
|
|
def val_dataloader(self):
    r"""Implement one or more PyTorch DataLoaders for validation.

    The dataloader is not re-requested every epoch unless you set
    ``Trainer(reload_dataloaders_every_epoch=True)``.

    It's recommended that all data downloads and preparation happen in ``prepare_data()``.

    The related hooks run in this order:

    - .fit()
    - ...
    - prepare_data()
    - train_dataloader
    - val_dataloader

    .. note:: Lightning adds the correct sampler for distributed and arbitrary hardware.
        There is no need to set one yourself.

    Return:
        PyTorch DataLoader (the default implementation returns ``None``).

    Examples:
        .. code-block:: python

            def val_dataloader(self):
                transform = transforms.Compose([transforms.ToTensor(),
                                                transforms.Normalize((0.5,), (1.0,))])
                dataset = MNIST(root='/path/to/mnist/', train=False,
                                transform=transform, download=True)
                loader = torch.utils.data.DataLoader(
                    dataset=dataset,
                    batch_size=self.hparams.batch_size,
                    shuffle=True
                )

                return loader

            # can also return multiple dataloaders
            def val_dataloader(self):
                return [loader_a, loader_b, ..., loader_n]

        .. code-block:: python

            @pl.data_loader
            def val_dataloader(self):
                transform = transforms.Compose([transforms.ToTensor(),
                                                transforms.Normalize((0.5,), (1.0,))])
                dataset = MNIST(root='/path/to/mnist/', train=False,
                                transform=transform, download=True)
                loader = torch.utils.data.DataLoader(
                    dataset=dataset,
                    batch_size=self.hparams.batch_size,
                    shuffle=True
                )

                return loader

            # can also return multiple dataloaders
            @pl.data_loader
            def val_dataloader(self):
                return [loader_a, loader_b, ..., loader_n]

    .. note:: If you don't need a validation dataset and a ``validation_step``, you
        don't need to implement this method.

    .. note:: If you want to change the data during every epoch, DON'T use the
        ``data_loader`` decorator.

    .. note:: If you return multiple ``val_dataloaders``, ``validation_step``
        receives an argument ``dataset_idx`` that matches the order here.
    """
|
2019-03-31 01:45:16 +00:00
|
|
|
|
|
|
|
@classmethod
def load_from_metrics(cls, weights_path, tags_csv, map_location=None):
    r"""Load a model from a checkpoint plus a separate hyperparameter ``.csv`` file.

    Emits a ``DeprecationWarning``, then forwards all arguments to
    ``cls.load_from_checkpoint``.

    Warning:
        Deprecated in version 0.7.0. You should use ``load_from_checkpoint`` instead.
        Will be removed in v0.9.0.
    """
    deprecation_msg = (
        "`load_from_metrics` method has been unified with `load_from_checkpoint` in v0.7.0."
        " The deprecated method will be removed in v0.9.0."
    )
    warnings.warn(deprecation_msg, DeprecationWarning)
    return cls.load_from_checkpoint(weights_path, tags_csv=tags_csv, map_location=map_location)
|
2019-07-25 14:39:48 +00:00
|
|
|
|
2019-10-23 08:48:24 +00:00
|
|
|
@classmethod
def load_from_checkpoint(
        cls,
        checkpoint_path: str,
        map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None,
        tags_csv: Optional[str] = None,
) -> 'LightningModule':
    r"""Primary way of loading a model from a checkpoint.

    When Lightning saves a checkpoint, it stores the hyperparameters in it if you
    initialized your LightningModule with an argument called ``hparams`` that is a
    Namespace (the output of using argparse to parse command line arguments).

    Example:
        .. code-block:: python

            from argparse import Namespace

            class MyModel(LightningModule):
                def __init__(self, hparams):
                    super().__init__()
                    self.learning_rate = hparams.learning_rate

            hparams = Namespace(**{'learning_rate': 0.1})
            model = MyModel(hparams)

    Args:
        checkpoint_path: Path to checkpoint.
        map_location:
            If your checkpoint saved a GPU model and you now load on CPUs
            or a different number of GPUs, use this to map to the new setup.
            The behaviour is the same as in
            `torch.load <https://pytorch.org/docs/stable/torch.html#torch.load>`_.
            When omitted, all storages are loaded as deserialized (onto the CPU).
        tags_csv: Optional path to a .csv file with two columns (key, value)
            as in this example::

                key,value
                drop_prob,0.2
                batch_size,32

            You most likely won't need this, since Lightning always saves the
            hyperparameters to the checkpoint. However, if your checkpoint weights
            don't have the hyperparameters saved, use this to pass in a .csv file
            with the hparams you'd like to use. They are converted into an
            argparse.Namespace and passed into your LightningModule. When given,
            they replace any hyperparameters stored in the checkpoint.

    Return:
        LightningModule with loaded weights and hyperparameters (if available).

    Example:
        .. code-block:: python

            # load weights without mapping ...
            pretrained_model = MyLightningModule.load_from_checkpoint('path/to/checkpoint.ckpt')

            # or load weights mapping all weights from GPU 1 to GPU 0 ...
            map_location = {'cuda:1': 'cuda:0'}
            pretrained_model = MyLightningModule.load_from_checkpoint(
                'path/to/checkpoint.ckpt',
                map_location=map_location
            )

            # or load weights and hyperparameters from separate files.
            pretrained_model = MyLightningModule.load_from_checkpoint(
                'path/to/checkpoint.ckpt',
                tags_csv='/path/to/hparams_file.csv'
            )

            # predict
            pretrained_model.eval()
            pretrained_model.freeze()
            y_hat = pretrained_model(x)
    """
    # ``is None`` rather than truthiness: a device index of 0 is a valid map_location
    if map_location is None:
        map_location = lambda storage, loc: storage  # noqa: E731  (keep tensors on CPU)
    checkpoint = torch.load(checkpoint_path, map_location=map_location)

    if tags_csv is not None:
        # hparams read from the csv file overwrite whatever the checkpoint stored
        hparams = load_hparams_from_tags_csv(tags_csv)
        setattr(hparams, 'on_gpu', False)
        checkpoint['hparams'] = vars(hparams)

    return cls._load_model_state(checkpoint)
|
|
|
|
|
|
|
|
@classmethod
def _load_model_state(cls, checkpoint):
    """Instantiate ``cls`` from a loaded checkpoint dict and restore its weights.

    Args:
        checkpoint (dict): Loaded checkpoint. ``'state_dict'`` is required;
            ``'hparams'`` and ``'hparams_type'`` are optional.

    Return:
        A new ``cls`` instance with ``checkpoint['state_dict']`` loaded and
        ``on_load_checkpoint(checkpoint)`` already called.

    Raises:
        MisconfigurationException: if the checkpoint contains hyperparameters but
            ``cls.__init__`` has no ``hparams`` argument.
    """
    cls_takes_hparams = 'hparams' in inspect.signature(cls.__init__).parameters
    ckpt_hparams = checkpoint.get('hparams')

    if cls_takes_hparams:
        if ckpt_hparams is not None:
            # hparams saved from an argparse.Namespace are restored as a Namespace
            is_namespace = checkpoint.get('hparams_type') == 'namespace'
            hparams = Namespace(**ckpt_hparams) if is_namespace else ckpt_hparams
        else:
            warnings.warn(
                f"Checkpoint does not contain hyperparameters but {cls.__name__}'s __init__ "
                "contains argument 'hparams'. Will pass in an empty Namespace instead."
                " Did you forget to store your model hyperparameters in self.hparams?"
            )
            hparams = Namespace()
    else:  # The user's LightningModule does not define a hparams argument
        if ckpt_hparams is None:
            hparams = None
        else:
            raise MisconfigurationException(
                f"Checkpoint contains hyperparameters but {cls.__name__}'s __init__ "
                "is missing the argument 'hparams'. Are you loading the correct checkpoint?"
            )

    # load the state_dict on the model automatically.
    # Decide on ``cls_takes_hparams`` instead of the truthiness of ``hparams``: a
    # falsy but valid value such as an empty dict must still reach ``__init__``.
    model_args = [hparams] if cls_takes_hparams else []
    model = cls(*model_args)
    model.load_state_dict(checkpoint['state_dict'])

    # give model a chance to load something
    model.on_load_checkpoint(checkpoint)

    return model
|
|
|
|
|
2019-10-08 19:30:06 +00:00
|
|
|
def summarize(self, mode):
    """Log a ``ModelSummary`` of this module, built with the given ``mode``.

    Args:
        mode: Forwarded unchanged to ``ModelSummary(self, mode=mode)``.
    """
    summary_text = str(ModelSummary(self, mode=mode))
    log.info('\n' + summary_text)
|
2019-07-25 16:01:52 +00:00
|
|
|
|
|
|
|
def freeze(self):
    r"""Freeze all parameters and switch the module to eval mode, e.g. for inference.

    Example:
        .. code-block:: python

            model = MyLightningModule(...)
            model.freeze()
    """
    for weight in self.parameters():
        # stop autograd from tracking gradients for this parameter
        weight.requires_grad = False

    self.eval()
|
|
|
|
|
2019-07-25 16:01:52 +00:00
|
|
|
def unfreeze(self):
    r"""
    Unfreeze all params for training.

    Re-enables ``requires_grad`` on every tensor yielded by ``self.parameters()``
    and then puts the module back into training mode via ``self.train()``.

    Example:
        .. code-block:: python

            model = MyLightningModule(...)
            model.unfreeze()

    """
    for weight in self.parameters():
        # Make every parameter trainable again (undoes ``freeze``).
        weight.requires_grad = True

    self.train()
|
2019-11-28 17:48:55 +00:00
|
|
|
|
|
|
|
def on_load_checkpoint(self, checkpoint):
    r"""
    Called by Lightning to restore your model.

    If you saved something with **on_save_checkpoint** this is your chance to restore it.
    The default implementation does nothing and returns ``None``.

    Args:
        checkpoint (dict): Loaded checkpoint

    Example:
        .. code-block:: python

            def on_load_checkpoint(self, checkpoint):
                # 99% of the time you don't need to implement this method
                self.something_cool_i_want_to_save = checkpoint['something_cool_i_want_to_save']

    .. note:: Lightning auto-restores global step, epoch, and train state including amp scaling.
        No need for you to restore anything regarding training.
    """
|
|
|
|
|
|
|
|
def on_save_checkpoint(self, checkpoint):
    r"""
    Called by Lightning when saving a checkpoint to give you a chance to store anything
    else you might want to save.

    The default implementation does nothing: it leaves ``checkpoint`` unchanged and
    returns ``None``.

    Args:
        checkpoint (dict): Checkpoint to be saved

    Example:
        .. code-block:: python

            def on_save_checkpoint(self, checkpoint):
                # 99% of use cases you don't need to implement this method
                checkpoint['something_cool_i_want_to_save'] = my_cool_picklable_object

    .. note:: Lightning saves all aspects of training (epoch, global step, etc...)
        including amp scaling. No need for you to store anything about training.

    """
|
2020-01-17 11:03:31 +00:00
|
|
|
|
2020-02-05 11:24:43 +00:00
|
|
|
def get_tqdm_dict(self):
    r"""
    Additional items to be displayed in the progress bar.

    The ``'loss'`` entry is always present: ``self.trainer.avg_loss`` formatted
    with three decimals. ``'split_idx'`` is added only when the trainer has
    ``truncated_bptt_steps`` set. ``'v_num'`` is added only when a logger is
    attached and its ``version`` is not ``None``.

    Return:
        Dictionary with the items to be displayed in the progress bar.
    """
    trainer = self.trainer
    bar_items = {'loss': '{:.3f}'.format(trainer.avg_loss)}

    # Truncated backprop-through-time: expose which split is being processed.
    if trainer.truncated_bptt_steps is not None:
        bar_items['split_idx'] = trainer.split_idx

    has_versioned_logger = trainer.logger is not None and trainer.logger.version is not None
    if has_versioned_logger:
        bar_items['v_num'] = trainer.logger.version

    return bar_items
|