.. _converting:

######################################
How to Organize PyTorch Into Lightning
######################################

To enable your code to work with Lightning, follow the steps below to organize your PyTorch code.

--------

*******************************
1. Keep your Computational Code
*******************************

Keep your regular ``nn.Module`` architecture:

.. testcode::

    import pytorch_lightning as pl
    import torch
    import torch.nn as nn
    import torch.nn.functional as F


    class Encoder(nn.Module):
        def __init__(self):
            super().__init__()
            self.layer_1 = nn.Linear(28 * 28, 128)
            self.layer_2 = nn.Linear(128, 10)

        def forward(self, x):
            x = x.view(x.size(0), -1)
            x = self.layer_1(x)
            x = F.relu(x)
            x = self.layer_2(x)
            return x

--------

****************************
2. Configure Training Logic
****************************

In the ``training_step`` of the LightningModule, configure how your training routine behaves with a batch of training data:

.. testcode::

    class LitModel(pl.LightningModule):
        def __init__(self, encoder):
            super().__init__()
            self.encoder = encoder

        def training_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self.encoder(x)
            loss = F.cross_entropy(y_hat, y)
            return loss

.. note:: If you need to fully own the training loop for complicated legacy projects, check out :doc:`Own your loop <../model/own_your_loop>`.
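
If you also want the training loss in your logger and progress bar, ``self.log`` can be called from ``training_step`` as well. A minimal sketch (the metric name ``"train_loss"`` is just an illustrative choice):

.. code-block:: python

    class LitModel(pl.LightningModule):
        def training_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self.encoder(x)
            loss = F.cross_entropy(y_hat, y)
            # in training_step, self.log records the value per step by default
            self.log("train_loss", loss)
            return loss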

----

******************************************
3. Move Optimizer(s) and LR Scheduler(s)
******************************************

Move your optimizers to the :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_optimizers` hook.

.. testcode::

    class LitModel(pl.LightningModule):
        def configure_optimizers(self):
            optimizer = torch.optim.Adam(self.encoder.parameters(), lr=1e-3)
            lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
            return [optimizer], [lr_scheduler]
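
If a scheduler should instead step on a monitored metric (for example ``ReduceLROnPlateau``), ``configure_optimizers`` can also return a dictionary. A minimal sketch, assuming a ``"val_loss"`` metric like the one logged in the validation step below:

.. code-block:: python

    class LitModel(pl.LightningModule):
        def configure_optimizers(self):
            optimizer = torch.optim.Adam(self.encoder.parameters(), lr=1e-3)
            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
            # "monitor" names the logged metric that the scheduler reads
            return {
                "optimizer": optimizer,
                "lr_scheduler": {"scheduler": scheduler, "monitor": "val_loss"},
            }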

--------

*****************************************
4. Organize Validation Logic (optional)
*****************************************

If you need a validation loop, configure how your validation routine behaves with a batch of validation data:

.. testcode::

    class LitModel(pl.LightningModule):
        def validation_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self.encoder(x)
            val_loss = F.cross_entropy(y_hat, y)
            self.log("val_loss", val_loss)

.. tip:: ``trainer.validate()`` loads the best checkpoint automatically by default if checkpointing was enabled during fitting.
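
For example, after ``trainer.fit(...)`` has run with checkpointing enabled, the validation loop can be re-run against the best checkpoint. A sketch (``val_loader`` is an assumed, user-defined DataLoader):

.. code-block:: python

    # "best" resolves to the best checkpoint written during fitting
    trainer.validate(model, dataloaders=val_loader, ckpt_path="best")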

--------

**************************************
5. Organize Testing Logic (optional)
**************************************

If you need a test loop, configure how your testing routine behaves with a batch of test data:

.. testcode::

    class LitModel(pl.LightningModule):
        def test_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self.encoder(x)
            test_loss = F.cross_entropy(y_hat, y)
            self.log("test_loss", test_loss)

--------

******************************************
6. Configure Prediction Logic (optional)
******************************************

If you need a prediction loop, configure how your prediction routine behaves with a batch of data:

.. testcode::

    class LitModel(pl.LightningModule):
        def predict_step(self, batch, batch_idx):
            x, y = batch
            pred = self.encoder(x)
            return pred

--------

********************************************
7. Remove any .cuda() or .to(device) Calls
********************************************

Your :doc:`LightningModule <../common/lightning_module>` can automatically run on any hardware!
If you have any explicit calls to ``.cuda()`` or ``.to(device)``, you can remove them, since Lightning makes sure that the data coming from the :class:`~torch.utils.data.DataLoader`
and all the :class:`~torch.nn.Module` instances initialized inside ``LightningModule.__init__`` are moved to the respective devices automatically.

If you still need to access the current device, you can use ``self.device`` anywhere in your ``LightningModule`` except in the ``__init__`` and ``setup`` methods.

.. testcode::

    class LitModel(pl.LightningModule):
        def training_step(self, batch, batch_idx):
            z = torch.randn(4, 5, device=self.device)
            ...

.. hint:: If you are initializing a :class:`~torch.Tensor` within the ``LightningModule.__init__`` method and want it to be moved to the device automatically, you should call :meth:`~torch.nn.Module.register_buffer` to register it as a buffer.

.. testcode::

    class LitModel(pl.LightningModule):
        def __init__(self, num_features):
            super().__init__()
            self.register_buffer("running_mean", torch.zeros(num_features))

--------

**********************
8. Use your own data
**********************

Regular PyTorch DataLoaders work with Lightning. For more modular and scalable datasets, check out :doc:`LightningDataModule <../data/datamodule>`.
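
For example, a plain ``DataLoader`` can be passed straight to ``trainer.fit()``. A minimal sketch (the dummy tensors are an illustrative stand-in for a real dataset):

.. code-block:: python

    from torch.utils.data import DataLoader, TensorDataset

    # dummy 28x28 "images" with integer labels, standing in for real data
    dataset = TensorDataset(torch.randn(64, 1, 28, 28), torch.randint(0, 10, (64,)))
    train_loader = DataLoader(dataset, batch_size=32)

    model = LitModel(Encoder())
    trainer = pl.Trainer(max_epochs=1)
    trainer.fit(model, train_loader)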

----

**************
Good to know
**************

Additionally, you can run only the validation loop using the :meth:`~pytorch_lightning.trainer.trainer.Trainer.validate` method.

.. code-block:: python

    model = LitModel(Encoder())
    trainer = pl.Trainer()
    trainer.validate(model)

.. note:: ``model.eval()`` and ``torch.no_grad()`` are called automatically for validation.

The test loop isn't used within :meth:`~pytorch_lightning.trainer.trainer.Trainer.fit`, so you need to call :meth:`~pytorch_lightning.trainer.trainer.Trainer.test` explicitly.

.. code-block:: python

    model = LitModel(Encoder())
    trainer = pl.Trainer()
    trainer.test(model)

.. note:: ``model.eval()`` and ``torch.no_grad()`` are called automatically for testing.

.. tip:: ``trainer.test()`` loads the best checkpoint automatically by default if checkpointing is enabled.

The predict loop will not be used until you call :meth:`~pytorch_lightning.trainer.trainer.Trainer.predict`.

.. code-block:: python

    model = LitModel(Encoder())
    trainer = pl.Trainer()
    trainer.predict(model)

.. note:: ``model.eval()`` and ``torch.no_grad()`` are called automatically for prediction.

.. tip:: ``trainer.predict()`` loads the best checkpoint automatically by default if checkpointing is enabled.
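
:meth:`~pytorch_lightning.trainer.trainer.Trainer.predict` also returns the collected outputs of ``predict_step``. A sketch (``predict_loader`` is an assumed, user-defined DataLoader):

.. code-block:: python

    # returns a list with one prediction per batch
    predictions = trainer.predict(model, dataloaders=predict_loader)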