.. testsetup:: *

    import torch
    import torch.nn.functional as F
    from pytorch_lightning.core.lightning import LightningModule
    from pytorch_lightning.core.datamodule import LightningDataModule
    from pytorch_lightning.trainer.trainer import Trainer

.. _converting:

**************************************
How to organize PyTorch into Lightning
**************************************

To make your code work with Lightning, organize your PyTorch code as described in the steps below.

--------

1. Move your computational code
===============================

Move the model architecture and forward pass to your :class:`~pytorch_lightning.core.LightningModule`.

.. testcode::

    class LitModel(LightningModule):

        def __init__(self):
            super().__init__()
            self.layer_1 = torch.nn.Linear(28 * 28, 128)
            self.layer_2 = torch.nn.Linear(128, 10)

        def forward(self, x):
            # Flatten each input to a vector of 28 * 28 features
            x = x.view(x.size(0), -1)
            x = self.layer_1(x)
            x = F.relu(x)
            x = self.layer_2(x)
            return x
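
As a quick sanity check, the sketch below runs the same two layers on a dummy batch. It assumes MNIST-sized inputs of shape ``(batch, 1, 28, 28)`` (an assumption suggested by ``28 * 28``) and uses a plain ``torch.nn.Module`` stand-in, hypothetically named ``PlainModel``, so it runs without Lightning installed. Calling the module, as in ``model(x)``, dispatches to ``forward``, which is why the later steps can write ``self(x)``.

.. code-block:: python

    import torch
    import torch.nn.functional as F


    class PlainModel(torch.nn.Module):  # hypothetical stand-in for LitModel
        def __init__(self):
            super().__init__()
            self.layer_1 = torch.nn.Linear(28 * 28, 128)
            self.layer_2 = torch.nn.Linear(128, 10)

        def forward(self, x):
            x = x.view(x.size(0), -1)  # flatten to (batch, 784)
            x = F.relu(self.layer_1(x))
            return self.layer_2(x)


    model = PlainModel()
    dummy_batch = torch.randn(4, 1, 28, 28)  # assumed MNIST-like input shape
    logits = model(dummy_batch)  # calling the module runs forward()
    print(logits.shape)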

--------

2. Move the optimizer(s) and schedulers
=======================================

Move your optimizers (and any learning rate schedulers) to the :func:`~pytorch_lightning.core.LightningModule.configure_optimizers` hook.

.. testcode::

    class LitModel(LightningModule):

        def configure_optimizers(self):
            optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
            return optimizer
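
The heading also mentions schedulers. One return format the hook accepts is a pair of lists, optimizers first and learning rate schedulers second. The sketch below builds that pair in a free function, hypothetically named ``build_optimizers``, so it can run without Lightning; in a ``LightningModule`` the same body would live in ``configure_optimizers`` and use ``self.parameters()``. The ``StepLR`` settings are illustrative assumptions.

.. code-block:: python

    import torch


    def build_optimizers(parameters):
        """Hypothetical helper mirroring the body of configure_optimizers."""
        optimizer = torch.optim.Adam(parameters, lr=1e-3)
        # Illustrative choice: multiply the learning rate by 0.1 every 10 epochs.
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
        return [optimizer], [scheduler]


    layer = torch.nn.Linear(4, 2)
    optimizers, schedulers = build_optimizers(layer.parameters())

    # Simulate 10 epochs; with this return format Lightning steps the
    # scheduler once per epoch by default.
    for _ in range(10):
        optimizers[0].step()
        schedulers[0].step()

    current_lr = optimizers[0].param_groups[0]["lr"]
    print(current_lr)

After ten simulated epochs the learning rate has been scaled once by ``gamma``.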

--------

3. Find the train loop "meat"
=============================

Lightning automates most of the training loop for you, including the epoch and batch iteration. All you need to keep is the training step logic.
This should go into the :func:`~pytorch_lightning.core.LightningModule.training_step` hook (make sure to use the hook parameters, ``batch`` and ``batch_idx`` in this case):

.. testcode::

    class LitModel(LightningModule):

        def training_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self(x)
            loss = F.cross_entropy(y_hat, y)
            return loss
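
For orientation, this is roughly the plain PyTorch loop that Lightning runs on your behalf around ``training_step``. It is a simplified sketch, not the Trainer's actual implementation; the random dataset, batch size, and epoch count are assumptions made only so the example runs.

.. code-block:: python

    import torch
    import torch.nn.functional as F
    from torch.utils.data import DataLoader, TensorDataset

    torch.manual_seed(0)

    # Toy stand-ins (assumptions): random "images" and labels, a tiny linear model.
    dataset = TensorDataset(torch.randn(64, 28 * 28), torch.randint(0, 10, (64,)))
    train_loader = DataLoader(dataset, batch_size=16)
    model = torch.nn.Linear(28 * 28, 10)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)


    def training_step(batch, batch_idx):
        # The only part you keep when moving to Lightning.
        x, y = batch
        y_hat = model(x)
        return F.cross_entropy(y_hat, y)


    losses = []
    for epoch in range(2):  # epoch iteration: handled by Lightning
        for batch_idx, batch in enumerate(train_loader):  # batch iteration: handled by Lightning
            loss = training_step(batch, batch_idx)
            optimizer.zero_grad()  # handled by Lightning
            loss.backward()  # handled by Lightning
            optimizer.step()  # handled by Lightning
            losses.append(loss.item())

    print(f"ran {len(losses)} training steps")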

--------

4. Find the val loop "meat"
===========================

To add an (optional) validation loop, add logic to the
:func:`~pytorch_lightning.core.LightningModule.validation_step` hook (make sure to use the hook parameters, ``batch`` and ``batch_idx`` in this case).

.. testcode::

    class LitModel(LightningModule):

        def validation_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self(x)
            val_loss = F.cross_entropy(y_hat, y)
            return val_loss

.. note:: ``model.eval()`` and ``torch.no_grad()`` are called automatically for validation.
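
In plain PyTorch terms, the automation described in the note corresponds roughly to the following. This is a simplified sketch under assumed toy data, not Lightning's internal code.

.. code-block:: python

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)

    # Toy stand-ins (assumptions): a model with dropout and two random batches.
    model = torch.nn.Sequential(torch.nn.Dropout(0.5), torch.nn.Linear(8, 3))
    val_batches = [(torch.randn(4, 8), torch.randint(0, 3, (4,))) for _ in range(2)]

    model.eval()  # disable dropout / use running batch-norm statistics
    with torch.no_grad():  # skip building the autograd graph
        val_losses = []
        for batch_idx, (x, y) in enumerate(val_batches):
            val_loss = F.cross_entropy(model(x), y)
            val_losses.append(val_loss)
    model.train()  # restore training mode before training resumes

    print(val_losses[0].requires_grad)  # False: no graph was recorded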

--------

5. Find the test loop "meat"
============================

To add an (optional) test loop, add logic to the
:func:`~pytorch_lightning.core.LightningModule.test_step` hook (make sure to use the hook parameters, ``batch`` and ``batch_idx`` in this case).

.. testcode::

    class LitModel(LightningModule):

        def test_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self(x)
            loss = F.cross_entropy(y_hat, y)
            return loss

.. note:: ``model.eval()`` and ``torch.no_grad()`` are called automatically for testing.

The test loop will not be used until you call:

.. code-block:: python

    trainer.test()

.. tip:: ``.test()`` loads the best checkpoint automatically.

--------

6. Remove any .cuda() or .to(device) calls
==========================================

Your :class:`~pytorch_lightning.core.LightningModule` can automatically run on any hardware!
Lightning moves the model and each batch to the correct device for you, so explicit device calls are no longer needed.
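
When a step needs to create a new tensor, a device-agnostic pattern is to derive the device from a tensor you already have instead of hard-coding ``.cuda()``. The sketch below is a minimal illustration in plain PyTorch; the function name ``add_noise`` is hypothetical.

.. code-block:: python

    import torch


    def add_noise(x):
        """Hypothetical helper: create a new tensor on the same device as x."""
        # Instead of torch.randn(...).cuda(), reuse the device and dtype of x.
        noise = torch.randn(x.shape, device=x.device, dtype=x.dtype)
        return x + noise


    x = torch.zeros(2, 3)  # lives on the CPU here; would work unchanged on a GPU
    out = add_noise(x)
    print(out.device == x.device)  # True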