.. testsetup:: *

    from pytorch_lightning.core.lightning import LightningModule
    from pytorch_lightning.trainer.trainer import Trainer

Quick Start
===========

PyTorch Lightning is nothing more than organized PyTorch code.
Once you've organized your code into a LightningModule, Lightning automates most of the training for you.

To illustrate, here's the typical PyTorch project structure organized in a LightningModule.

.. figure:: /_images/mnist_imgs/pt_to_pl.jpg
    :alt: Convert from PyTorch to Lightning

Step 1: Define a LightningModule
---------------------------------

.. testcode::
    :skipif: not TORCHVISION_AVAILABLE

    import os

    import torch
    from torch.nn import functional as F
    from torch.utils.data import DataLoader
    from torchvision.datasets import MNIST
    from torchvision import transforms

    from pytorch_lightning.core.lightning import LightningModule

    class LitModel(LightningModule):

        def __init__(self):
            super().__init__()
            self.l1 = torch.nn.Linear(28 * 28, 10)

        def forward(self, x):
            return torch.relu(self.l1(x.view(x.size(0), -1)))

        def training_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self(x)
            loss = F.cross_entropy(y_hat, y)
            tensorboard_logs = {'train_loss': loss}
            return {'loss': loss, 'log': tensorboard_logs}

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=0.001)

        def train_dataloader(self):
            dataset = MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())
            loader = DataLoader(dataset, batch_size=32, num_workers=4, shuffle=True)
            return loader

Step 2: Fit with a Trainer
--------------------------

.. testcode::
    :skipif: torch.cuda.device_count() < 8

    from pytorch_lightning import Trainer

    model = LitModel()

    # most basic trainer, uses good defaults
    trainer = Trainer(gpus=8, num_nodes=1)
    trainer.fit(model)

Under the hood, Lightning does the following (in high-level pseudocode):

.. code-block:: python

    model = LitModel()
    train_dataloader = model.train_dataloader()
    optimizer = model.configure_optimizers()

    for epoch in epochs:
        train_outs = []
        for batch in train_dataloader:
            loss = model.training_step(batch)
            loss.backward()
            train_outs.append(loss.detach())

            optimizer.step()
            optimizer.zero_grad()

        # optional, for logging etc...
        model.training_epoch_end(train_outs)

Validation loop
---------------

To add a validation loop, add the following functions:

.. testcode::

    class LitModel(LightningModule):

        def validation_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self(x)
            return {'val_loss': F.cross_entropy(y_hat, y)}

        def validation_epoch_end(self, outputs):
            avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
            tensorboard_logs = {'val_loss': avg_loss}
            return {'val_loss': avg_loss, 'log': tensorboard_logs}

        def val_dataloader(self):
            # TODO: do a real train/val split
            dataset = MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor())
            loader = DataLoader(dataset, batch_size=32, num_workers=4)
            return loader

The trainer will now call the validation loop automatically:

.. code-block:: python

    # most basic trainer, uses good defaults
    trainer = Trainer(gpus=8, num_nodes=1)
    trainer.fit(model)

Under the hood, Lightning does the following (in pseudocode):

.. testsetup:: *

    train_dataloader = []

.. testcode::

    # ...
    for batch in train_dataloader:
        loss = model.training_step(batch)
        loss.backward()
        # ...

        if validate_at_some_point:
            model.eval()
            val_outs = []
            for val_batch in model.val_dataloader():
                val_out = model.validation_step(val_batch)
                val_outs.append(val_out)

            model.validation_epoch_end(val_outs)
            model.train()

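By default the validation loop runs once per training epoch, but when validate_at_some_point fires is
configurable. As a minimal sketch (reusing the val_check_interval flag that also appears later in this guide),
you can check validation several times per training epoch:

.. code-block:: python

    # sketch: run the validation loop 4 times per training epoch
    # (0.25 = check validation every quarter of an epoch)
    trainer = Trainer(gpus=8, num_nodes=1, val_check_interval=0.25)
    trainer.fit(model)
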
The beauty of Lightning is that it handles the details for you: when to validate, when to call .eval(),
turning off gradients, detaching graphs, making sure shuffle is disabled for the val set, etc...

.. note:: Lightning removes all the million details you need to remember during research.

Test loop
---------

You might also need a test loop:

.. testcode::

    class LitModel(LightningModule):

        def test_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self(x)
            return {'test_loss': F.cross_entropy(y_hat, y)}

        def test_epoch_end(self, outputs):
            avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
            tensorboard_logs = {'test_loss': avg_loss}
            return {'avg_test_loss': avg_loss, 'log': tensorboard_logs}

        def test_dataloader(self):
            # TODO: do a real train/test split
            dataset = MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor())
            loader = DataLoader(dataset, batch_size=32, num_workers=4)
            return loader

However, this time you need to call test explicitly (this is done so you don't use the test set by mistake):

.. code-block:: python

    # OPTION 1:
    # test after fit
    trainer.fit(model)
    trainer.test()

    # OPTION 2:
    # test after loading weights
    model = LitModel.load_from_checkpoint(PATH)
    trainer = Trainer(tpu_cores=1)
    trainer.test(model)

Again, under the hood, Lightning does the following (in pseudocode):

.. code-block:: python

    model.eval()
    test_outs = []
    for test_batch in model.test_dataloader():
        test_out = model.test_step(test_batch)
        test_outs.append(test_out)

    model.test_epoch_end(test_outs)

Datasets
--------

If you don't want to define the datasets as part of the LightningModule, just pass them into fit instead.

.. code-block:: python

    # pass in datasets if you want.
    train_dataloader = DataLoader(dataset, batch_size=32, num_workers=4)
    val_dataloader, test_dataloader = ...

    trainer = Trainer(gpus=8, num_nodes=1)
    trainer.fit(model, train_dataloader, val_dataloader)

    trainer.test(test_dataloaders=test_dataloader)

The advantage of this approach is that you can reuse the same model with different datasets. The disadvantage
is that, for research, it makes readability and reproducibility harder. This is why we recommend defining the
datasets in the LightningModule if you're doing research, and using the approach above for production models
or prediction tasks.

Why do you need Lightning?
--------------------------

Notice the code above has nothing about .cuda(), 16-bit precision, early stopping, logging, etc...
This is where Lightning adds a ton of value.

Without changing a SINGLE line of your code, you can now do the following with the code above:

.. code-block:: python

    # train on TPUs using 16-bit precision with early stopping,
    # using only half the training data and checking validation every quarter of a training epoch
    trainer = Trainer(
        tpu_cores=8,
        precision=16,
        early_stop_callback=True,
        train_percent_check=0.5,
        val_check_interval=0.25
    )

    # train on 256 GPUs
    trainer = Trainer(
        gpus=8,
        num_nodes=32
    )

    # train on 1024 CPUs across 128 machines
    trainer = Trainer(
        num_processes=8,
        num_nodes=128
    )

And the best part is that your code is STILL just PyTorch... meaning you can do anything you
would normally do.

.. code-block:: python

    model = LitModel()
    model.eval()

    y_hat = model(x)

    model.anything_you_can_do_with_pytorch()

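For example, here is a minimal sketch of plain PyTorch inference with a trained LightningModule. It reuses
the load_from_checkpoint call and the placeholder PATH from the test loop section above; x stands for any
batch of MNIST-shaped inputs:

.. code-block:: python

    # sketch: load trained weights, then run ordinary PyTorch inference
    model = LitModel.load_from_checkpoint(PATH)
    model.eval()

    # standard PyTorch: no gradients needed for prediction
    with torch.no_grad():
        y_hat = model(x)
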
Summary
-------
In short, by refactoring your PyTorch code:

1. You STILL keep pure PyTorch.
2. You DON'T lose any flexibility.
3. You can get rid of all of your boilerplate.
4. You make your code generalizable to any hardware.
5. Your code is now readable and easier to reproduce (i.e., you help with the reproducibility crisis).
6. Your LightningModule is still just a pure PyTorch module.

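To see the pieces above in one place, here is a minimal, hedged sketch of an end-to-end run. It assumes a
machine with no accelerators, so the hardware flags are simply dropped, and max_epochs is set only to keep
the run short:

.. code-block:: python

    from pytorch_lightning import Trainer

    model = LitModel()

    # no gpus/tpu_cores flags -> plain CPU training
    trainer = Trainer(max_epochs=1)
    trainer.fit(model)

    # evaluate on the test set defined in test_dataloader()
    trainer.test(model)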