:orphan:
#################################
Validate and test a model (basic)
#################################
**Audience**: Users who want to add a validation loop to avoid overfitting.
----
***************
Add a test loop
***************
To make sure a model can generalize to an unseen dataset (e.g., for a published paper or a production environment), a dataset is normally split into two parts: the *train* split and the *test* split.
The test set is **NOT** used during training. It is **ONLY** used once the model has been trained, to estimate how the model will perform in the real world.
----
Find the train and test splits
==============================
Many datasets come with these two splits predefined. Refer to the dataset documentation to find the *train* and *test* splits; for MNIST, the ``train`` argument selects between them.
.. code-block :: python

    import torch.utils.data as data
    from torchvision import datasets
    import torchvision.transforms as transforms

    # Load data sets
    transform = transforms.ToTensor()
    train_set = datasets.MNIST(root="MNIST", download=True, train=True, transform=transform)
    test_set = datasets.MNIST(root="MNIST", download=True, train=False, transform=transform)

----
Define the test loop
====================
To add a test loop, implement the **test_step** method of the LightningModule.
.. code :: python

    import lightning as L
    import torch.nn.functional as F


    class LitAutoEncoder(L.LightningModule):
        def training_step(self, batch, batch_idx):
            ...

        def test_step(self, batch, batch_idx):
            # this is the test loop
            x, _ = batch  # labels are not needed for a reconstruction loss
            x = x.view(x.size(0), -1)  # flatten each image into a vector
            z = self.encoder(x)  # encoder/decoder are defined in __init__ (not shown)
            x_hat = self.decoder(z)
            test_loss = F.mse_loss(x_hat, x)
            self.log("test_loss", test_loss)

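
To make the arithmetic in **test_step** concrete, here is a minimal pure-Python sketch (no PyTorch needed) of the flatten-then-MSE computation for a single batch. The helpers ``flatten`` and ``mse_loss`` are hypothetical stand-ins: ``flatten`` plays the role of ``x.view(x.size(0), -1)``, and ``mse_loss`` assumes the default ``reduction="mean"`` of ``F.mse_loss``, which averages the squared error over every element in the batch.

.. code-block :: python

    # Hypothetical pure-Python stand-ins for x.view(x.size(0), -1) and F.mse_loss.
    # A "batch" here is a list of images; each image is a list of rows of floats.


    def flatten(batch):
        """Turn each image (list of rows) into one flat vector."""
        return [[pixel for row in image for pixel in row] for image in batch]


    def mse_loss(x_hat, x):
        """Mean squared error averaged over every element of the batch."""
        squared_errors = [
            (a - b) ** 2
            for vec_hat, vec in zip(x_hat, x)
            for a, b in zip(vec_hat, vec)
        ]
        return sum(squared_errors) / len(squared_errors)


    # Two 2x2 "images" and a slightly-off reconstruction of each
    batch = [[[0.0, 1.0], [1.0, 0.0]], [[0.5, 0.5], [0.5, 0.5]]]
    x = flatten(batch)  # [[0.0, 1.0, 1.0, 0.0], [0.5, 0.5, 0.5, 0.5]]
    x_hat = [[0.0, 1.0, 1.0, 0.5], [0.5, 0.5, 0.5, 0.0]]
    test_loss = mse_loss(x_hat, x)
    print(test_loss)  # 0.0625

Because the mean runs over all elements, flattening does not change the loss value; it only reshapes each image into the vector the encoder expects.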
----
Train with the test loop
========================
Once the model has finished training, call **.test** on the Trainer.
.. code-block :: python

    import lightning as L
    from torch.utils.data import DataLoader

    # initialize the Trainer
    trainer = L.Trainer()

    # test the model
    trainer.test(model, dataloaders=DataLoader(test_set))

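
``trainer.test`` calls **test_step** once per batch and reduces the values passed to ``self.log`` into a single epoch-level ``test_loss``. The sketch below is a simplified pure-Python model of that reduction, assuming Lightning's default mean reduction weighted by each batch's size; ``epoch_mean`` is a hypothetical helper, not Lightning code.

.. code-block :: python

    # Simplified model (hypothetical helper, not Lightning's code) of how per-batch
    # test losses become one epoch-level "test_loss": a mean weighted by batch size.


    def epoch_mean(step_losses, batch_sizes):
        """Batch-size-weighted average of the per-step losses."""
        total = sum(loss * size for loss, size in zip(step_losses, batch_sizes))
        return total / sum(batch_sizes)


    # Three test batches where the last one is smaller than the others
    step_losses = [0.25, 0.125, 0.5]
    batch_sizes = [4, 4, 2]

    weighted = epoch_mean(step_losses, batch_sizes)
    naive = sum(step_losses) / len(step_losses)  # over-weights the small last batch
    print(weighted)  # 0.25

``DataLoader(test_set)`` uses PyTorch's default ``batch_size=1``, so every batch has the same size and the weighted and plain means coincide.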
----
*********************
Add a validation loop
*********************
During training, it's common practice to use a small portion of the train split to determine when the model has finished training.
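
One common way to turn the validation loss into a stopping decision is patience-based early stopping: stop once the monitored loss has not improved for a fixed number of consecutive validations. Lightning provides an ``EarlyStopping`` callback for this in ``lightning.pytorch.callbacks`` (configured with ``monitor="val_loss"``). The plain-Python sketch below is not that callback's implementation; it only models the patience rule, assuming a lower loss is better. The function ``stopping_epoch`` and the example losses are illustrative.

.. code-block :: python

    def stopping_epoch(val_losses, patience=3):
        """Return the index of the epoch at which training would stop.

        Training stops after `patience` consecutive epochs without a new best
        (lowest) validation loss. If that never happens, training runs to the end.
        """
        best = float("inf")
        epochs_without_improvement = 0
        for epoch, loss in enumerate(val_losses):
            if loss < best:
                best = loss
                epochs_without_improvement = 0
            else:
                epochs_without_improvement += 1
                if epochs_without_improvement >= patience:
                    return epoch
        return len(val_losses) - 1


    # The validation loss bottoms out at epoch 2, then rises as the model overfits.
    val_losses = [0.9, 0.6, 0.5, 0.55, 0.58, 0.61, 0.7]
    print(stopping_epoch(val_losses, patience=3))  # 5

The training loss alone would keep decreasing here; only the held-out validation loss reveals that the model stopped generalizing after epoch 2.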
----
Split the training data
=======================
As a rule of thumb, we use 20% of the training set as the **validation set**. This number varies from dataset to dataset.
.. code-block :: python

    import torch

    # use 20% of training data for validation
    train_set_size = int(len(train_set) * 0.8)
    valid_set_size = len(train_set) - train_set_size

    # split the train set into two (a fixed seed makes the split reproducible)
    seed = torch.Generator().manual_seed(42)
    train_set, valid_set = data.random_split(train_set, [train_set_size, valid_set_size], generator=seed)

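
With MNIST's 60,000 training images, the 80/20 rule above yields 48,000 training and 12,000 validation samples. The plain-Python sketch below mimics what a seeded split does: shuffle the indices with a fixed seed, then cut them into two disjoint subsets. It uses Python's ``random`` module, so its permutation will not match the one produced by ``torch.Generator().manual_seed(42)``; ``split_indices`` is a hypothetical helper.

.. code-block :: python

    import random


    def split_indices(n, train_fraction=0.8, seed=42):
        """Shuffle range(n) with a fixed seed and cut it into train/validation index lists."""
        train_size = int(n * train_fraction)
        indices = list(range(n))
        random.Random(seed).shuffle(indices)  # same seed -> same permutation every run
        return indices[:train_size], indices[train_size:]


    train_idx, valid_idx = split_indices(60_000)
    print(len(train_idx), len(valid_idx))  # 48000 12000

Computing the validation size as ``n - train_size`` (rather than ``int(n * 0.2)``) guarantees the two parts always add up to the full dataset, even when ``n * 0.8`` is not a whole number.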
----
Define the validation loop
==========================
To add a validation loop, implement the **validation_step** method of the LightningModule.
.. code :: python

    import lightning as L
    import torch.nn.functional as F


    class LitAutoEncoder(L.LightningModule):
        def training_step(self, batch, batch_idx):
            ...

        def validation_step(self, batch, batch_idx):
            # this is the validation loop
            x, _ = batch  # labels are not needed for a reconstruction loss
            x = x.view(x.size(0), -1)  # flatten each image into a vector
            z = self.encoder(x)  # encoder/decoder are defined in __init__ (not shown)
            x_hat = self.decoder(z)
            val_loss = F.mse_loss(x_hat, x)
            self.log("val_loss", val_loss)

----
Train with the validation loop
==============================
To run the validation loop, pass the validation set, wrapped in a DataLoader, to **.fit**.
.. code-block :: python

    import lightning as L
    from torch.utils.data import DataLoader

    train_loader = DataLoader(train_set)
    valid_loader = DataLoader(valid_set)
    model = LitAutoEncoder(...)

    # train with both splits
    trainer = L.Trainer()
    trainer.fit(model, train_loader, valid_loader)
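
By default, ``trainer.fit`` runs the validation loop at the end of every training epoch (``check_val_every_n_epoch=1``) and, before training starts, runs a short sanity check over up to two validation batches (``num_sanity_val_steps=2``). The sketch below is a simplified plain-Python model of that schedule, useful for reasoning about when ``val_loss`` gets logged. ``fit_schedule`` is a hypothetical helper that ignores details such as ``val_check_interval`` and gradients being disabled during validation.

.. code-block :: python

    def fit_schedule(num_epochs, train_batches, val_batches, num_sanity_val_steps=2):
        """Return (stage, hook, batch_idx) tuples in a simplified model of trainer.fit."""
        calls = []
        # Sanity check: a few validation batches before any training happens
        for batch_idx in range(min(num_sanity_val_steps, val_batches)):
            calls.append(("sanity_check", "validation_step", batch_idx))
        for epoch in range(num_epochs):
            for batch_idx in range(train_batches):
                calls.append((epoch, "training_step", batch_idx))
            # Validation runs at the end of every epoch (check_val_every_n_epoch=1)
            for batch_idx in range(val_batches):
                calls.append((epoch, "validation_step", batch_idx))
        return calls


    for call in fit_schedule(num_epochs=2, train_batches=3, val_batches=2):
        print(call)

The sanity check catches bugs in **validation_step** before a potentially long first training epoch, rather than after it.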