.. testsetup:: *

    from pytorch_lightning.core.lightning import LightningModule
    from pytorch_lightning.trainer.trainer import Trainer
    import os
    import torch
    from torch.nn import functional as F
    from torch.utils.data import DataLoader
    import pytorch_lightning as pl
    from torch.utils.data import random_split

.. _quick-start:

Quick Start
===========

PyTorch Lightning is nothing more than organized PyTorch code.
Once you've organized it into a LightningModule, it automates most of the training for you.

Here's a 2 minute conversion guide for PyTorch projects:

----------

Step 1: Build LightningModule
-----------------------------

A LightningModule defines:

- Train loop
- Val loop
- Test loop
- Model + system architecture
- Optimizer

.. code-block:: python

    import os

    import torch
    import torch.nn.functional as F
    from torchvision.datasets import MNIST
    from torchvision import transforms
    from torch.utils.data import DataLoader, random_split

    import pytorch_lightning as pl

    class LitModel(pl.LightningModule):

        def __init__(self):
            super().__init__()
            self.l1 = torch.nn.Linear(28 * 28, 10)

        def forward(self, x):
            return torch.relu(self.l1(x.view(x.size(0), -1)))

        def training_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self(x)
            loss = F.cross_entropy(y_hat, y)
            return loss

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=0.0005)

----------

Step 2: Fit with a Trainer
--------------------------

The Trainer calls each loop at the correct time as needed,
and makes sure everything works across any accelerator.

Here's an example of using the Trainer:

.. code-block:: python

    # dataloader
    dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())
    train_loader = DataLoader(dataset)

    # init model
    model = LitModel()

    # most basic trainer, uses good defaults (auto-tensorboard, checkpoints, logs, and more)
    trainer = pl.Trainer()
    trainer.fit(model, train_loader)

Using GPUs/TPUs
^^^^^^^^^^^^^^^
It's trivial to use GPUs or TPUs in Lightning. There's NO NEED to change your code;
simply change the Trainer options.

.. code-block:: python

    # train on 1, 2, 4, n GPUs
    Trainer(gpus=1)
    Trainer(gpus=2)
    Trainer(gpus=8, num_nodes=n)

    # train on TPUs
    Trainer(tpu_cores=8)
    Trainer(tpu_cores=128)

    # even half precision
    Trainer(gpus=2, precision=16)

The code above gives you the following for free:

- Automatic checkpoints
- Automatic TensorBoard (or the logger of your choice)
- Automatic CPU/GPU/TPU training
- Automatic 16-bit precision

All of it is 100% rigorously tested and benchmarked.

--------------

Lightning under the hood
^^^^^^^^^^^^^^^^^^^^^^^^
Lightning is designed for state-of-the-art research and is built by researchers
and research engineers from top AI labs.

A LightningModule handles advanced cases by letting you override any critical part of training
via hooks that are called on your LightningModule.
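For example, here's a minimal sketch of overriding two such hooks on the ``LitModel`` from Step 1:
``configure_optimizers`` (to attach a learning-rate scheduler) and ``on_epoch_end``.
The ``StepLR`` scheduler and the print statement are illustrative choices, not part of the example above.

.. code-block:: python

    class LitModel(pl.LightningModule):
        # ... __init__, forward and training_step as in Step 1 ...

        def configure_optimizers(self):
            # override this hook to control the optimization strategy
            optimizer = torch.optim.Adam(self.parameters(), lr=0.0005)
            scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
            return [optimizer], [scheduler]

        def on_epoch_end(self):
            # this hook runs at the end of every training epoch
            print(f'finished epoch {self.current_epoch}')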
----------------

Training loop under the hood
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
This is the training loop pseudocode that Lightning runs under the hood:

.. code-block:: python

    # init model
    model = LitModel()

    # enable training
    torch.set_grad_enabled(True)
    model.train()

    # get data + optimizer
    train_dataloader = model.train_dataloader()
    optimizer = model.configure_optimizers()

    for epoch in epochs:
        for batch in train_dataloader:

            # forward (TRAINING_STEP)
            loss = model.training_step(batch)

            # backward
            loss.backward()

            # apply and clear grads
            optimizer.step()
            optimizer.zero_grad()

Main take-aways:

- Lightning sets ``.train()`` and enables gradients when entering the training loop.
- Lightning iterates over the epochs automatically.
- Lightning iterates over the dataloaders automatically.
- ``training_step`` gives you full control of the main loop.
- ``.backward()``, ``.step()`` and ``.zero_grad()`` are called for you. BUT, you can override this if you need manual control.

----------

Adding a Validation loop
------------------------
To add an (optional) validation loop, add the following method:

.. testcode::

    class LitModel(LightningModule):

        def validation_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self(x)
            loss = F.cross_entropy(y_hat, y)
            result = pl.EvalResult(checkpoint_on=loss)
            result.log('val_loss', loss)
            return result

.. note:: ``EvalResult`` is a plain dict with convenience functions for logging.

And now the trainer will call the validation loop automatically:

.. code-block:: python

    # pass in the val dataloader to the trainer as well
    trainer.fit(model, train_dataloader, val_dataloader)

Validation loop under the hood
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Under the hood, Lightning does the following (in pseudocode):

.. code-block:: python

    # ...
    for batch in train_dataloader:
        loss = model.training_step(batch)
        loss.backward()
        # ...

        if validate_at_some_point:
            # disable grads + batchnorm + dropout
            torch.set_grad_enabled(False)
            model.eval()

            val_outs = []
            for val_batch in model.val_dataloader():
                val_out = model.validation_step(val_batch)
                val_outs.append(val_out)

            model.validation_epoch_end(val_outs)

            # enable grads + batchnorm + dropout
            torch.set_grad_enabled(True)
            model.train()

Lightning automatically:

- Enables gradients and sets the model to ``train()`` in the train loop
- Disables gradients and sets the model to ``eval()`` in the val loop
- Re-enables gradients and sets the model back to ``train()`` after the val loop ends

-------------

Adding a Test loop
------------------
You might also need an optional test loop:

.. testcode::

    class LitModel(LightningModule):

        def test_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self(x)
            loss = F.cross_entropy(y_hat, y)
            result = pl.EvalResult()
            result.log('test_loss', loss)
            return result

However, this time you need to call test explicitly (this is done so you don't use the test set by mistake):

.. code-block:: python

    # OPTION 1:
    # test after fit
    trainer.fit(model)
    trainer.test(test_dataloaders=test_dataloader)

    # OPTION 2:
    # test after loading weights
    model = LitModel.load_from_checkpoint(PATH)
    trainer = Trainer()
    trainer.test(test_dataloaders=test_dataloader)

Test loop under the hood
^^^^^^^^^^^^^^^^^^^^^^^^
Under the hood, Lightning does the following (in pseudocode):

.. code-block:: python

    # disable grads + batchnorm + dropout
    torch.set_grad_enabled(False)
    model.eval()

    test_outs = []
    for test_batch in model.test_dataloader():
        test_out = model.test_step(test_batch)
        test_outs.append(test_out)

    model.test_epoch_end(test_outs)

    # enable grads + batchnorm + dropout
    torch.set_grad_enabled(True)
    model.train()
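Putting the pieces together, here's a minimal sketch of wiring up the test loop with the MNIST
example from Step 2. The MNIST test split and the ``test_loader`` name are illustrative choices,
not part of the snippets above.

.. code-block:: python

    # hypothetical test split for the MNIST example from Step 2
    test_dataset = MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor())
    test_loader = DataLoader(test_dataset)

    model = LitModel()
    trainer = pl.Trainer()

    # fit on the train data, then explicitly run the test loop
    trainer.fit(model, train_loader)
    trainer.test(test_dataloaders=test_loader)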
---------------

Data
----
Lightning operates on standard PyTorch dataloaders (of any flavor). Use dataloaders in 3 ways.

Data in fit
^^^^^^^^^^^
Pass the dataloaders into ``trainer.fit()``:

.. code-block:: python

    trainer.fit(model, train_dataloader, val_dataloader)

Data in LightningModule
^^^^^^^^^^^^^^^^^^^^^^^
For fast research prototyping, it might be easier to link the model with the dataloaders.

.. code-block:: python

    class LitModel(pl.LightningModule):

        def train_dataloader(self):
            # your train transforms
            return DataLoader(YOUR_DATASET)

        def val_dataloader(self):
            # your val transforms
            return DataLoader(YOUR_DATASET)

        def test_dataloader(self):
            # your test transforms
            return DataLoader(YOUR_DATASET)

And fit like so:

.. code-block:: python

    model = LitModel()
    trainer.fit(model)

DataModule
^^^^^^^^^^
A more reusable approach is to define a DataModule, which is simply a collection of all 3 data splits that also captures:

- download instructions
- processing
- splitting
- etc...

Here's how to refactor the MNIST code above into a reusable DataModule:

.. code-block:: python

    class MNISTDataModule(pl.LightningDataModule):

        def __init__(self, batch_size=32):
            super().__init__()
            self.batch_size = batch_size

        def prepare_data(self):
            # optional: lets the download run only once when using multiple GPUs or TPUs
            MNIST(os.getcwd(), train=True, download=True)
            MNIST(os.getcwd(), train=False, download=True)

        def setup(self, stage):
            transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,))
            ])
            if stage == 'fit':
                mnist_train = MNIST(os.getcwd(), train=True, transform=transform)
                self.mnist_train, self.mnist_val = random_split(mnist_train, [55000, 5000])
            if stage == 'test':
                self.mnist_test = MNIST(os.getcwd(), train=False, transform=transform)

        def train_dataloader(self):
            return DataLoader(self.mnist_train, batch_size=self.batch_size)

        def val_dataloader(self):
            return DataLoader(self.mnist_val, batch_size=self.batch_size)

        def test_dataloader(self):
            return DataLoader(self.mnist_test, batch_size=self.batch_size)

And train like so:

.. code-block:: python

    dm = MNISTDataModule()
    trainer.fit(model, dm)

When doing distributed training, DataModules have two optional methods (``prepare_data`` and ``setup``)
that give you granular control over how the data is downloaded, prepared and split:

.. code-block:: python

    class MyDataModule(pl.LightningDataModule):

        def prepare_data(self):
            # called only on 1 GPU
            download()
            tokenize()
            etc()

        def setup(self, stage=None):
            # called on every GPU (assigning state is OK)
            self.train = ...
            self.val = ...

        def train_dataloader(self):
            # do more...
            return DataLoader(self.train)

Building models based on Data
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
DataModules are the recommended approach when the model depends on properties of the data.
First, define the information that you might need.

.. code-block:: python

    class MyDataModule(pl.LightningDataModule):

        def __init__(self):
            super().__init__()
            self.train_dims = None
            self.vocab_size = 0

        def prepare_data(self):
            # called only on 1 GPU
            download_dataset()
            tokenize()
            build_vocab()

        def setup(self, stage=None):
            # called on every GPU
            vocab = load_vocab()
            self.vocab_size = len(vocab)

            self.train, self.val, self.test = load_datasets()
            self.train_dims = self.train.next_batch.size()

        def train_dataloader(self):
            # apply any transforms inside the dataset, not the DataLoader
            return DataLoader(self.train)

        def val_dataloader(self):
            return DataLoader(self.val)

        def test_dataloader(self):
            return DataLoader(self.test)
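The model that consumes these properties could then take them as constructor arguments. Here's a
hedged sketch of what that might look like; ``image_width`` and ``vocab_length`` are hypothetical
names, chosen only to match the snippet that follows, and the layer sizes are arbitrary:

.. code-block:: python

    class LitModel(pl.LightningModule):

        def __init__(self, image_width, vocab_length):
            super().__init__()
            # layer sizes come from the data rather than hard-coded constants
            self.encoder = torch.nn.Linear(image_width, 128)
            self.decoder = torch.nn.Linear(128, vocab_length)

        def forward(self, x):
            return self.decoder(torch.relu(self.encoder(x)))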
Next, materialize the data and build your model:

.. code-block:: python

    # build the datamodule
    dm = MyDataModule()
    dm.prepare_data()
    dm.setup()

    # pass in the properties you want
    model = LitModel(image_width=dm.train_dims[0], vocab_length=dm.vocab_size)

    # train
    trainer.fit(model, dm)

-----------------

Logging/progress bar
--------------------

.. image:: /_images/mnist_imgs/mnist_tb.png
    :width: 300
    :align: center
    :alt: Example TB logs

Lightning has built-in logging to any of the supported loggers, as well as to the progress bar.

Log in train loop
^^^^^^^^^^^^^^^^^
To log from the training loop, use the ``log`` method of the ``TrainResult``.

.. code-block:: python

    def training_step(self, batch, batch_idx):
        loss = ...
        result = pl.TrainResult(minimize=loss)
        result.log('train_loss', loss)
        return result

The ``TrainResult`` gives you options for logging on every step and/or at the end of the epoch.
It also allows logging to the progress bar.

.. code-block:: python

    # equivalent
    result.log('train_loss', loss)
    result.log('train_loss', loss, prog_bar=False, logger=True, on_step=True, on_epoch=False)

Then boot up your logger or TensorBoard instance to view the training logs:

.. code-block:: bash

    tensorboard --logdir ./lightning_logs

.. warning:: Refreshing the progress bar too frequently in Jupyter notebooks or Colab may freeze your UI.
    We recommend you set ``Trainer(progress_bar_refresh_rate=10)``.

Log in Val/Test loop
^^^^^^^^^^^^^^^^^^^^
To log from the validation or test loop, use the ``EvalResult``.

.. code-block:: python

    def validation_step(self, batch, batch_idx):
        loss = ...
        acc = ...
        result = pl.EvalResult()
        result.log_dict({'val_loss': loss, 'val_acc': acc})
        return result

Log to the progress bar
^^^^^^^^^^^^^^^^^^^^^^^

.. code-block:: shell

    Epoch 1:   4%|▎ | 40/1095 [00:03<01:37, 10.84it/s, loss=4.501, v_num=10]

In addition to logging to your logger, you can log to the progress bar by setting ``prog_bar`` to True.

.. code-block:: python

    def training_step(self, batch, batch_idx):
        loss = ...
        result = pl.TrainResult(loss)
        result.log('train_loss', loss, prog_bar=True)
        return result

-----------------

Advanced loop aggregation
-------------------------
For certain train/val/test loops, you may wish to do more than just logging. In this case,
you can also implement the ``*_epoch_end`` methods (e.g. ``validation_epoch_end``), which give
you the outputs of every step in that epoch.

Here's the motivating PyTorch example:

.. code-block:: python

    validation_step_outputs = []
    for batch_idx, batch in enumerate(val_dataloader()):
        out = validation_step(batch, batch_idx)
        validation_step_outputs.append(out)

    validation_epoch_end(validation_step_outputs)

And the Lightning equivalent:

.. code-block:: python

    def validation_step(self, batch, batch_idx):
        loss = ...
        predictions = ...
        result = pl.EvalResult(checkpoint_on=loss)
        result.log('val_loss', loss)
        result.predictions = predictions
        return result

    def validation_epoch_end(self, validation_step_outputs):
        all_val_losses = validation_step_outputs.val_loss
        all_predictions = validation_step_outputs.predictions
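As an illustration of what you might do with the aggregated outputs, here's a hedged sketch that
attaches predictions and targets in ``validation_step`` and computes an epoch-level accuracy.
The attribute names and the defensive concatenation are assumptions; depending on your Lightning
version, the attached values may come back as a single tensor or as a list of per-batch tensors.

.. code-block:: python

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        result = pl.EvalResult(checkpoint_on=loss)
        result.log('val_loss', loss)
        # attach anything you want to aggregate at the end of the epoch
        result.predictions = y_hat.argmax(dim=-1)
        result.targets = y
        return result

    def validation_epoch_end(self, validation_step_outputs):
        preds = validation_step_outputs.predictions
        targets = validation_step_outputs.targets
        # concatenate if the outputs arrive as a list of per-batch tensors
        if isinstance(preds, list):
            preds, targets = torch.cat(preds), torch.cat(targets)
        accuracy = (preds == targets).float().mean()
        print(f'epoch accuracy: {accuracy.item():.3f}')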
Why do you need Lightning?
--------------------------
The MAIN takeaway points are:

- Lightning is for professional AI researchers/production teams.
- Lightning is organized PyTorch. It is not an abstraction.
- You STILL keep pure PyTorch.
- You DON'T lose any flexibility.
- You can get rid of all of your boilerplate.
- You make your code generalizable to any hardware.
- Your code is now readable and easier to reproduce (ie: you help with the reproducibility crisis).
- Your LightningModule is still just a pure PyTorch module.

Lightning is for you if
^^^^^^^^^^^^^^^^^^^^^^^

- You're a professional researcher/ML engineer working on non-trivial deep learning.
- You already know PyTorch and are not a beginner.
- You want to iterate through research much faster.
- You want to put models into production much faster.
- You need full control of all the details but don't need the boilerplate.
- You want to leverage code written by hundreds of AI researchers, research engineers and PhDs from the world's top AI labs.
- You need GPUs, multi-node training, half-precision and TPUs.
- You want research code that is rigorously tested (500+ tests) across CPUs/multi-GPUs/multi-TPUs on every pull request.

Some more cool features
^^^^^^^^^^^^^^^^^^^^^^^
Here are (some of) the other things you can do with Lightning:

- Automatic checkpointing.
- Automatic early stopping.
- Automatically overfit your model for a sanity test.
- Automatic truncated back-propagation through time.
- Automatically scale your batch size.
- Automatically attempt to find a good learning rate.
- Add arbitrary callbacks.
- Hit every line of your code once to see if you have bugs (instead of waiting hours to crash on validation ;).
- Load checkpoints directly from S3.
- Move from CPUs to GPUs or TPUs without code changes.
- Profile your code for speed/memory bottlenecks.
- Scale to massive compute clusters.
- Use multiple dataloaders per train/val/test loop.
- Use multiple optimizers to do reinforcement learning or even GANs.

Example:
^^^^^^^^
Without changing a SINGLE line of your code, you can now do the following with the code above:

.. code-block:: python

    # train on TPUs using 16-bit precision with early stopping,
    # using only half the training data and checking validation every quarter of a training epoch
    trainer = Trainer(
        tpu_cores=8,
        precision=16,
        early_stop_callback=True,
        limit_train_batches=0.5,
        val_check_interval=0.25
    )

    # train on 256 GPUs
    trainer = Trainer(
        gpus=8,
        num_nodes=32
    )

    # train on 1024 CPUs across 128 machines
    trainer = Trainer(
        num_processes=8,
        num_nodes=128
    )

And the best part is that your code is STILL just PyTorch... meaning you can do anything you
would normally do.

.. code-block:: python

    model = LitModel()
    model.eval()

    y_hat = model(x)

    model.anything_you_can_do_with_pytorch()

---------------

Masterclass
-----------
You can learn Lightning in depth by watching our Masterclass.

.. image:: _images/general/PTL101_youtube_thumbnail.jpg
    :width: 500
    :align: center
    :alt: Masterclass
    :target: https://www.youtube.com/playlist?list=PLaMu-SDt_RB5NUm67hU2pdE75j6KaIOv2