.. testsetup:: *

    from pytorch_lightning.core.lightning import LightningModule
    from pytorch_lightning.trainer.trainer import Trainer

Quick Start
===========

PyTorch Lightning is nothing more than organized PyTorch code.
Once you've organized it into a LightningModule, it automates most of the training for you.

To illustrate, here's the typical PyTorch project structure organized in a LightningModule.

.. figure:: /_images/mnist_imgs/pt_to_pl.jpg
    :alt: Convert from PyTorch to Lightning

Step 1: Define a LightningModule
---------------------------------

.. testcode::
    :skipif: not TORCHVISION_AVAILABLE

    import os
    import torch
    from torch.nn import functional as F
    from torch.utils.data import DataLoader
    from torchvision.datasets import MNIST
    from torchvision import transforms
    from pytorch_lightning.core.lightning import LightningModule

    class LitModel(LightningModule):

        def __init__(self):
            super().__init__()
            # a single linear layer: flattened 28x28 image -> 10 class scores
            self.l1 = torch.nn.Linear(28 * 28, 10)

        def forward(self, x):
            # flatten each image in the batch before the linear layer
            return torch.relu(self.l1(x.view(x.size(0), -1)))

        def training_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self(x)
            loss = F.cross_entropy(y_hat, y)
            # metrics under 'log' are sent to the logger (TensorBoard by default)
            tensorboard_logs = {'train_loss': loss}
            return {'loss': loss, 'log': tensorboard_logs}

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=0.001)

        def train_dataloader(self):
            dataset = MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())
            loader = DataLoader(dataset, batch_size=32, num_workers=4, shuffle=True)
            return loader

Step 2: Fit with a Trainer
--------------------------

.. testcode::
    :skipif: torch.cuda.device_count() < 8

    from pytorch_lightning import Trainer

    model = LitModel()

    # most basic trainer, uses good defaults
    trainer = Trainer(gpus=8, num_nodes=1)
    trainer.fit(model)

Under the hood, Lightning does the following (in high-level pseudocode):

.. code-block:: python

    model = LitModel()
    train_dataloader = model.train_dataloader()
    optimizer = model.configure_optimizers()

    for epoch in range(max_epochs):
        train_outs = []
        for batch_idx, batch in enumerate(train_dataloader):
            # simplified: Step 1's training_step wraps the loss in a dict under 'loss'
            loss = model.training_step(batch, batch_idx)
            loss.backward()
            train_outs.append(loss.detach())

            optimizer.step()
            optimizer.zero_grad()

        # optional for logging, etc...
        model.training_epoch_end(train_outs)

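
To make the loop above concrete without needing PyTorch or a GPU, here is a minimal plain-Python sketch that calls the hooks in the same order. ``ToyModule``, ``ToyLoss``, ``ToyOptimizer`` and ``run_fit`` are hypothetical stand-ins made up for this illustration, not Lightning APIs; the real Trainer also handles device placement, logging, checkpointing and more.

.. code-block:: python

    class ToyLoss:
        """Hypothetical stand-in for a loss tensor: holds a float, logs backward() calls."""

        def __init__(self, value, log):
            self.value = value
            self.log = log

        def backward(self):
            self.log.append("backward")

        def detach(self):
            return self.value


    class ToyOptimizer:
        """Hypothetical stand-in for torch.optim.Adam: only records step/zero_grad calls."""

        def __init__(self, log):
            self.log = log

        def step(self):
            self.log.append("step")

        def zero_grad(self):
            self.log.append("zero_grad")


    class ToyModule:
        """Hypothetical module exposing the same hook names as LitModel."""

        def __init__(self):
            self.log = []

        def train_dataloader(self):
            return [1.0, 2.0, 3.0]  # three fake "batches"

        def configure_optimizers(self):
            return ToyOptimizer(self.log)

        def training_step(self, batch, batch_idx):
            self.log.append(f"training_step({batch_idx})")
            return ToyLoss(batch * 0.5, self.log)

        def training_epoch_end(self, outputs):
            self.log.append(f"training_epoch_end(n={len(outputs)})")


    def run_fit(model, max_epochs=1):
        """Drive the hooks in the order shown in the pseudocode above."""
        train_dataloader = model.train_dataloader()
        optimizer = model.configure_optimizers()
        for _ in range(max_epochs):
            train_outs = []
            for batch_idx, batch in enumerate(train_dataloader):
                loss = model.training_step(batch, batch_idx)
                loss.backward()
                train_outs.append(loss.detach())
                optimizer.step()
                optimizer.zero_grad()
            model.training_epoch_end(train_outs)
        return model.log


    log = run_fit(ToyModule(), max_epochs=1)
    print(log[:4])   # ['training_step(0)', 'backward', 'step', 'zero_grad']
    print(log[-1])   # training_epoch_end(n=3)

Because ``run_fit`` only relies on hook names, any object that implements them can be driven by the same loop, which is the idea behind keeping the research code in the module and the loop in the Trainer.
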
Validation loop
---------------

To also add a validation loop, add the following methods:

.. testcode::

    class LitModel(LightningModule):

        def validation_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self(x)
            return {'val_loss': F.cross_entropy(y_hat, y)}

        def validation_epoch_end(self, outputs):
            # average the per-batch losses returned by validation_step
            avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
            tensorboard_logs = {'val_loss': avg_loss}
            return {'val_loss': avg_loss, 'log': tensorboard_logs}

        def val_dataloader(self):
            # TODO: do a real train/val split
            dataset = MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor())
            loader = DataLoader(dataset, batch_size=32, num_workers=4)
            return loader

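
``validation_epoch_end`` above gives every batch the same weight when it averages ``val_loss``. The plain-Python sketch below reproduces that reduction and contrasts it with a sample-weighted mean; the helpers ``mean_of_batch_means`` and ``per_sample_mean`` are hypothetical and the loss values are made up. The two results only differ when batch sizes differ, for example the final partial batch when the 10,000 MNIST test images are split into batches of 32 (312 full batches plus one batch of 16).

.. code-block:: python

    def mean_of_batch_means(outputs):
        """Same reduction as torch.stack([...]).mean(): every batch counts equally."""
        losses = [out["val_loss"] for out in outputs]
        return sum(losses) / len(losses)


    def per_sample_mean(outputs, batch_sizes):
        """Weights each batch's mean loss by how many samples it contained."""
        total = sum(out["val_loss"] * n for out, n in zip(outputs, batch_sizes))
        return total / sum(batch_sizes)


    # Two full batches of 32 and a final, smaller batch of 16 (made-up loss values).
    outputs = [{"val_loss": 0.5}, {"val_loss": 0.5}, {"val_loss": 2.0}]
    batch_sizes = [32, 32, 16]

    print(mean_of_batch_means(outputs))           # 1.0
    print(per_sample_mean(outputs, batch_sizes))  # 0.8
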
Now the trainer will call the validation loop automatically:

.. code-block:: python

    # most basic trainer, uses good defaults
    trainer = Trainer(gpus=8, num_nodes=1)
    trainer.fit(model)

Under the hood, Lightning does the following (in pseudocode):

.. testsetup:: *

    train_dataloader = []

.. testcode::

    # ...
    for batch_idx, batch in enumerate(train_dataloader):
        loss = model.training_step(batch, batch_idx)
        loss.backward()
        # ...

        if validate_at_some_point:
            # switch off dropout, batch-norm updates, etc.
            model.eval()
            val_outs = []
            for val_batch_idx, val_batch in enumerate(model.val_dataloader()):
                val_out = model.validation_step(val_batch, val_batch_idx)
                val_outs.append(val_out)

            model.validation_epoch_end(val_outs)
            # switch back to training mode
            model.train()

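
The mode bookkeeping in that pseudocode (switch to eval mode, run the validation batches, switch back) is easy to get wrong by hand. The sketch below mimics it with plain Python objects so it runs without PyTorch. ``ToyModel`` and ``run_epoch`` are hypothetical, and a fixed ``check_every`` (validate after every N training batches) is an assumed stand-in for ``validate_at_some_point``. Real code would also wrap validation in ``torch.no_grad()``, which the toy omits.

.. code-block:: python

    class ToyModel:
        """Hypothetical stand-in that only tracks train/eval mode, like nn.Module.training."""

        def __init__(self):
            self.training = True

        def train(self):
            self.training = True

        def eval(self):
            self.training = False


    def run_epoch(model, num_batches, check_every, val_batches=2):
        """Record (phase, batch_idx, model.training) for every train and val batch."""
        events = []
        for batch_idx in range(num_batches):
            events.append(("train", batch_idx, model.training))
            # plays the role of `validate_at_some_point` in the pseudocode above
            if (batch_idx + 1) % check_every == 0:
                model.eval()  # in a real nn.Module this disables dropout, etc.
                for val_idx in range(val_batches):
                    events.append(("val", val_idx, model.training))
                model.train()  # restore training mode before the next batch
        return events


    events = run_epoch(ToyModel(), num_batches=8, check_every=4)
    print([e for e in events if e[0] == "val"])
    # [('val', 0, False), ('val', 1, False), ('val', 0, False), ('val', 1, False)]
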
The beauty of Lightning is that it handles these details for you: deciding when to validate, calling ``.eval()`` at the right time,
turning off gradients, detaching graphs, making sure shuffling is disabled for validation, and so on.

.. note:: Lightning takes care of the many details you would otherwise need to remember during research.

Test loop
---------

You might also need a test loop:

.. testcode::

    class LitModel(LightningModule):

        def test_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self(x)
            return {'test_loss': F.cross_entropy(y_hat, y)}

        def test_epoch_end(self, outputs):
            avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
            tensorboard_logs = {'test_loss': avg_loss}
            return {'avg_test_loss': avg_loss, 'log': tensorboard_logs}

        def test_dataloader(self):
            # train=False selects the official MNIST test split
            dataset = MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor())
            loader = DataLoader(dataset, batch_size=32, num_workers=4)
            return loader

However, this time you need to call test explicitly (this is done so you don't use the test set by mistake):

.. code-block:: python

    # OPTION 1:
    # test after fit (the trainer reuses the fitted model)
    trainer.fit(model)
    trainer.test()

    # OPTION 2:
    # test after loading weights (a fresh trainer must be given the model)
    model = LitModel.load_from_checkpoint(PATH)
    trainer = Trainer(num_tpu_cores=1)
    trainer.test(model)

Again, under the hood, Lightning does the following (in pseudocode):

.. code-block:: python

    model.eval()
    test_outs = []
    for test_batch_idx, test_batch in enumerate(model.test_dataloader()):
        test_out = model.test_step(test_batch, test_batch_idx)
        test_outs.append(test_out)

    model.test_epoch_end(test_outs)

Datasets
--------

If you don't want to define the datasets as part of the LightningModule, just pass the dataloaders into ``fit`` and ``test`` instead.

.. code-block:: python

    # pass in dataloaders if you want.
    train_dataloader = DataLoader(dataset, batch_size=32, num_workers=4)
    val_dataloader, test_dataloader = ...

    trainer = Trainer(gpus=8, num_nodes=1)
    trainer.fit(model, train_dataloader, val_dataloader)

    trainer.test(test_dataloaders=test_dataloader)

The advantage of this method is the ability to reuse models for different datasets. The disadvantage
is that, for research, it makes readability and reproducibility more difficult. This is why we recommend
defining the datasets in the LightningModule if you're doing research, but using the method above for
production models or for prediction tasks.

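
One way to picture the trade-off is as a precedence rule: a dataloader passed to ``fit`` is used when given, and otherwise the module's own ``*_dataloader`` method is called. The plain-Python sketch below illustrates that rule as an assumption (check the ``Trainer.fit`` documentation for the version you use); ``resolve_loader``, ``ResearchModule`` and ``ProductionModule`` are hypothetical names, and the "batches" are just strings.

.. code-block:: python

    class ResearchModule:
        """Hypothetical module that defines its own training data."""

        def train_dataloader(self):
            return ["module batch 1", "module batch 2"]


    class ProductionModule:
        """Hypothetical module with no data hooks: data must come from the caller."""


    def resolve_loader(module, name, passed=None):
        """Assumed rule: use the loader passed to fit(), else call module.<name>()."""
        if passed is not None:
            return passed
        hook = getattr(module, name, None)
        if hook is None:
            raise ValueError(f"no {name} passed to fit() and none defined on the module")
        return hook()


    print(resolve_loader(ResearchModule(), "train_dataloader"))
    # ['module batch 1', 'module batch 2']
    print(resolve_loader(ResearchModule(), "train_dataloader", passed=["new data"]))
    # ['new data']

The research-style module needs no arguments, so everything required to reproduce it lives in one place, while the production-style module can be fit on whatever data the caller supplies.
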
Why do you need Lightning?
--------------------------

Notice that the code above contains nothing about ``.cuda()``, 16-bit precision, early stopping, logging, etc.
This is where Lightning adds a ton of value.

Without changing a SINGLE line of your code, you can now do the following with the above code:

.. code-block:: python

    # train on TPUs using 16-bit precision with early stopping,
    # using only half the training data and checking validation every quarter of a training epoch
    trainer = Trainer(
        num_tpu_cores=8,
        precision=16,
        early_stop_callback=True,
        train_percent_check=0.5,
        val_check_interval=0.25
    )

    # train on 256 GPUs (8 GPUs per node x 32 nodes)
    trainer = Trainer(
        gpus=8,
        num_nodes=32
    )

    # train on 1024 CPUs across 128 machines (8 processes per machine)
    trainer = Trainer(
        num_processes=8,
        num_nodes=128
    )

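
The comments in these examples come down to simple arithmetic: the number of workers is the per-node count times ``num_nodes``, and the two fractional flags scale the number of training batches. The sketch below works this through in plain Python for MNIST's 60,000 training images in batches of 32 on a single process. The helper names are hypothetical, and truncating with ``int`` is an assumed rounding rule that may not match Lightning's internals exactly.

.. code-block:: python

    def total_workers(per_node, num_nodes):
        """Devices (GPUs, TPU cores, or CPU processes) per machine times machines."""
        return per_node * num_nodes


    def batches_used(num_batches, train_percent_check):
        """Training batches kept per epoch (assumed: truncate to an int)."""
        return int(num_batches * train_percent_check)


    def val_checks_per_epoch(num_batches, val_check_interval):
        """Validation runs per epoch if it fires every `every` training batches."""
        every = max(1, int(num_batches * val_check_interval))
        return num_batches // every


    print(total_workers(8, 32))    # 256
    print(total_workers(8, 128))   # 1024

    mnist_batches = 60_000 // 32             # 1875 batches of 32 images
    kept = batches_used(mnist_batches, 0.5)  # half the training data
    print(kept)                              # 937
    print(val_checks_per_epoch(kept, 0.25))  # 4
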
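And the best part is that your code is STILL just PyTorch, meaning you can do anything you
would normally do.

.. code-block:: python

    model = LitModel()
    model.eval()

    x = torch.rand(1, 1, 28, 28)  # a fake MNIST-shaped batch
    y_hat = model(x)

    # ...and anything else you can do with a regular torch.nn.Module
    model.anything_you_can_do_with_pytorch()
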
Summary
-------

In short, by refactoring your PyTorch code:

1. You STILL keep pure PyTorch.
2. You DON'T lose any flexibility.
3. You can get rid of all of your boilerplate.
4. You make your code generalizable to any hardware.
5. Your code is now readable and easier to reproduce (i.e., you help with the reproducibility crisis).
6. Your LightningModule is still just a pure PyTorch module.