:orphan:

#################################
Validate and test a model (basic)
#################################

**Audience**: Users who want to add a validation loop to avoid overfitting

----


***************
Add a test loop
***************

To make sure a model can generalize to an unseen dataset (for example, before publishing a paper or deploying to production), a dataset is normally split into two parts: the *train* split and the *test* split.
The test set is **NOT** used during training; it is **ONLY** used once the model has been trained, to estimate how the model will perform in the real world.


----

Find the train and test splits
==============================

Most datasets come with predefined *train* and *test* splits. Refer to the dataset documentation to find them. With torchvision's MNIST, the ``train`` flag selects the split:


.. code-block:: python

    import torch.utils.data as data
    from torchvision import datasets
    import torchvision.transforms as transforms

    # Load data sets
    transform = transforms.ToTensor()
    train_set = datasets.MNIST(root="MNIST", download=True, train=True, transform=transform)
    test_set = datasets.MNIST(root="MNIST", download=True, train=False, transform=transform)


----

Define the test loop
====================

To add a test loop, implement the **test_step** method of the LightningModule:

.. code:: python

    import lightning as L
    import torch.nn.functional as F


    class LitAutoEncoder(L.LightningModule):
        # self.encoder and self.decoder are assumed to be defined in __init__ (omitted here)

        def training_step(self, batch, batch_idx):
            ...

        def test_step(self, batch, batch_idx):
            # this is the test loop
            x, _ = batch
            # flatten each image into a vector
            x = x.view(x.size(0), -1)
            z = self.encoder(x)
            x_hat = self.decoder(z)
            test_loss = F.mse_loss(x_hat, x)
            self.log("test_loss", test_loss)

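
To make the loss concrete, the plain-Python sketch below (no PyTorch needed) computes the same quantity as the test step for one tiny batch: each sample is flattened into a vector, as ``x.view(x.size(0), -1)`` does, and the squared error is averaged over every element, as ``F.mse_loss`` does with its default ``reduction="mean"``. The helpers ``flatten_batch`` and ``mse`` and the toy data are hypothetical and exist only for this illustration.

.. code-block:: python

    # Plain-Python sketch of what test_step computes for one batch.

    def flatten_batch(batch):
        """Turn a batch of 2-D samples into a batch of 1-D vectors."""
        return [[pixel for row in sample for pixel in row] for sample in batch]


    def mse(predictions, targets):
        """Mean squared error averaged over all elements of all samples."""
        squared_errors = [
            (p - t) ** 2
            for pred_vec, target_vec in zip(predictions, targets)
            for p, t in zip(pred_vec, target_vec)
        ]
        return sum(squared_errors) / len(squared_errors)


    # Two tiny 2x2 "images" and a hypothetical reconstruction of each.
    x = [[[0.0, 1.0], [1.0, 0.0]], [[0.5, 0.5], [0.5, 0.5]]]
    x_hat = [[[0.0, 0.5], [1.0, 0.0]], [[0.5, 0.5], [0.5, 1.0]]]

    flat_x = flatten_batch(x)  # [[0.0, 1.0, 1.0, 0.0], [0.5, 0.5, 0.5, 0.5]]
    flat_x_hat = flatten_batch(x_hat)
    print(mse(flat_x_hat, flat_x))  # 0.0625

Only two of the eight elements differ, each by 0.5, so the mean squared error is 2 * 0.25 / 8.
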

----

Train with the test loop
========================

Once the model has finished training, call **.test**:

.. code-block:: python

    import lightning as L
    from torch.utils.data import DataLoader

    # initialize the Trainer
    trainer = L.Trainer()

    # test the already-trained model on the held-out test split
    trainer.test(model, dataloaders=DataLoader(test_set))


----

*********************
Add a validation loop
*********************

During training, it's common practice to hold out a small portion of the train split and use it to determine when the model has finished training, for example when the loss on that held-out data stops improving.

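
A common way to turn this into a rule is early stopping: stop once the validation loss has not improved for a set number of epochs. The plain-Python sketch below illustrates that rule on a list of per-epoch validation losses; ``stopping_epoch`` and its parameters are illustrative names, and the loss values are made up. In Lightning, the ``lightning.pytorch.callbacks.EarlyStopping`` callback (for example ``EarlyStopping(monitor="val_loss", patience=3)``) applies a similar rule to a logged metric.

.. code-block:: python

    # Plain-Python sketch of patience-based early stopping on a validation
    # metric. It illustrates the idea only; it is not Lightning's
    # EarlyStopping implementation.

    def stopping_epoch(val_losses, patience=3, min_delta=0.0):
        """Return the index of the epoch after which training would stop.

        Training stops once the validation loss has failed to improve by more
        than ``min_delta`` for ``patience`` consecutive epochs. If that never
        happens, the last epoch index is returned.
        """
        best = float("inf")
        epochs_without_improvement = 0
        for epoch, loss in enumerate(val_losses):
            if loss < best - min_delta:
                # new best validation loss: reset the counter
                best = loss
                epochs_without_improvement = 0
            else:
                epochs_without_improvement += 1
                if epochs_without_improvement >= patience:
                    return epoch
        return len(val_losses) - 1


    # Hypothetical validation losses: improvement stalls after epoch 2.
    history = [0.90, 0.60, 0.45, 0.46, 0.47, 0.45, 0.50]
    print(stopping_epoch(history, patience=3))  # 5

Epoch 5 matches the best loss (0.45) but does not beat it, so it counts as the third epoch without improvement.
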

----

Split the training data
=======================

As a rule of thumb, we use 20% of the training set as the **validation set**. This number varies from dataset to dataset.

.. code-block:: python

    import torch

    # use 20% of training data for validation
    train_set_size = int(len(train_set) * 0.8)
    valid_set_size = len(train_set) - train_set_size

    # split the train set into two, with a fixed seed so the split is reproducible
    seed = torch.Generator().manual_seed(42)
    train_set, valid_set = data.random_split(train_set, [train_set_size, valid_set_size], generator=seed)

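
For MNIST, whose training split has 60,000 images, this works out to 48,000 training and 12,000 validation images. The plain-Python sketch below shows the idea behind a seeded random split: shuffle the indices with a seeded generator, then slice them. The helper ``split_indices`` is hypothetical, and because it uses Python's ``random`` module rather than ``torch.Generator``, it will not pick the same indices as ``random_split``.

.. code-block:: python

    import random

    # Plain-Python sketch of a seeded 80/20 train/validation split over indices.
    # Not PyTorch's implementation; it only mirrors the shuffle-then-slice idea.

    def split_indices(n, train_fraction=0.8, seed=42):
        train_size = int(n * train_fraction)
        indices = list(range(n))
        random.Random(seed).shuffle(indices)  # seeded, so the split is reproducible
        return indices[:train_size], indices[train_size:]


    train_idx, valid_idx = split_indices(60_000)  # size of MNIST's training split
    print(len(train_idx), len(valid_idx))  # 48000 12000

Every index lands in exactly one of the two splits, so no validation sample is ever seen during training.
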

----

Define the validation loop
==========================

To add a validation loop, implement the **validation_step** method of the LightningModule:

.. code:: python

    import lightning as L
    import torch.nn.functional as F


    class LitAutoEncoder(L.LightningModule):
        # self.encoder and self.decoder are assumed to be defined in __init__ (omitted here)

        def training_step(self, batch, batch_idx):
            ...

        def validation_step(self, batch, batch_idx):
            # this is the validation loop
            x, _ = batch
            # flatten each image into a vector
            x = x.view(x.size(0), -1)
            z = self.encoder(x)
            x_hat = self.decoder(z)
            val_loss = F.mse_loss(x_hat, x)
            self.log("val_loss", val_loss)

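
By default, values logged with ``self.log`` inside ``validation_step`` are accumulated over the whole validation epoch rather than reported per step. The sketch below is a simplified, plain-Python model of that reduction, assuming the default mean reduction weighted by each batch's size; ``epoch_mean`` and the batch values are hypothetical, and Lightning's real bookkeeping handles many more cases.

.. code-block:: python

    # Simplified sketch: turn per-batch mean losses into one epoch-level value
    # using a batch-size-weighted mean (assumed default behavior).

    def epoch_mean(batch_losses, batch_sizes):
        """Batch-size-weighted mean of per-batch mean losses."""
        total = sum(loss * size for loss, size in zip(batch_losses, batch_sizes))
        return total / sum(batch_sizes)


    # Hypothetical epoch: two full batches of 32 and a final partial batch of 16.
    losses = [0.50, 0.40, 0.10]
    sizes = [32, 32, 16]
    print(epoch_mean(losses, sizes))

For these numbers the weighted mean is 0.38 (up to floating-point rounding), while a plain average of the three batch losses would be about 0.33; weighting keeps the smaller final batch from counting as much as a full one.
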

----

Train with the validation loop
==============================

To run the validation loop, pass a DataLoader over the validation set to **.fit**:

.. code-block:: python

    import lightning as L
    from torch.utils.data import DataLoader

    train_loader = DataLoader(train_set)
    valid_loader = DataLoader(valid_set)
    model = LitAutoEncoder(...)

    # train with both splits
    trainer = L.Trainer()
    trainer.fit(model, train_loader, valid_loader)

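
Once a validation loader is passed to **.fit**, the Trainer interleaves validation with training. The schematic below records, in plain Python, the order of events under what we take to be Lightning's defaults: a short sanity check of ``num_sanity_val_steps=2`` validation batches before training, then a full validation pass after every training epoch (``check_val_every_n_epoch=1``). ``fit_schedule`` is a hypothetical helper; it only lists event names and does not reproduce Lightning's loop code.

.. code-block:: python

    # Plain-Python schematic of the default order of training and validation
    # batches during fit. Event names only; not Lightning's actual loop.

    def fit_schedule(num_epochs, num_train_batches, num_val_batches, num_sanity_val_steps=2):
        events = []
        # Sanity check: a few validation batches run before any training,
        # to surface bugs in validation_step early.
        for i in range(min(num_sanity_val_steps, num_val_batches)):
            events.append(f"sanity_val_batch_{i}")
        for epoch in range(num_epochs):
            for i in range(num_train_batches):
                events.append(f"epoch_{epoch}_train_batch_{i}")
            # One full validation pass follows every training epoch.
            for i in range(num_val_batches):
                events.append(f"epoch_{epoch}_val_batch_{i}")
        return events


    for event in fit_schedule(num_epochs=2, num_train_batches=2, num_val_batches=3):
        print(event)

Both behaviors can be adjusted through ``Trainer`` arguments such as ``num_sanity_val_steps`` and ``check_val_every_n_epoch``.
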