diff --git a/docs/source/introduction_guide.rst b/docs/source/introduction_guide.rst
index 5f0d4574e3..e6d8b33ce8 100644
--- a/docs/source/introduction_guide.rst
+++ b/docs/source/introduction_guide.rst
@@ -15,7 +15,7 @@ code to work with Lightning.
 
 .. raw:: html
 
-
+
 
 |
 
diff --git a/docs/source/new-project.rst b/docs/source/new-project.rst
index fb0c55f6d3..68b4f54afa 100644
--- a/docs/source/new-project.rst
+++ b/docs/source/new-project.rst
@@ -6,6 +6,8 @@
     import torch
     from torch.nn import functional as F
     from torch.utils.data import DataLoader
+    import pytorch_lightning as pl
+    from torch.utils.data import random_split
 
 .. _quick-start:
 
@@ -16,11 +18,11 @@
 PyTorch Lightning is nothing more than organized PyTorch code.
 Once you've organized it into a LightningModule, it automates most of the training for you.
 
-To illustrate, here's the typical PyTorch project structure organized in a LightningModule.
+Here's a 2-minute conversion guide for PyTorch projects:
 
 .. raw:: html
 
-
+
 
 ----------
 
@@ -34,12 +36,16 @@ A lightningModule defines
 
 - Model + system architecture
 - Optimizer
 
-.. testcode::
-    :skipif: not TORCHVISION_AVAILABLE
-
+.. code-block::
+
     import os
+    import torch
+    import torch.nn.functional as F
+    from torchvision.datasets import MNIST
+    from torchvision import transforms
+    from torch.utils.data import DataLoader
     import pytorch_lightning as pl
-    from pytorch_lightning.metrics.functional import accuracy
+    from torch.utils.data import random_split
 
     class LitModel(pl.LightningModule):
 
@@ -74,7 +80,7 @@ well across any accelerator.
 
 Here's an example of using the Trainer:
 
-.. code-block:: python
+.. code-block::
 
     # dataloader
     dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())
@@ -83,7 +89,7 @@ Here's an example of using the Trainer:
 
     # init model
     model = LitModel()
 
-    # most basic trainer, uses good defaults
+    # most basic trainer, uses good defaults (auto-tensorboard, checkpoints, logs, and more)
     trainer = pl.Trainer()
     trainer.fit(model, train_loader)
 
@@ -350,30 +356,49 @@ And the matching code:
 
 |
 
-.. code-block:: python
+.. code-block::
 
-    class MyDataModule(pl.DataModule):
+    class MNISTDataModule(pl.LightningDataModule):
 
-        def __init__(self):
-            ...
+        def __init__(self, batch_size=32):
+            super().__init__()
+            self.batch_size = batch_size
+
+        def prepare_data(self):
+            # optional: download only once when using multi-GPU or multi-TPU
+            MNIST(os.getcwd(), train=True, download=True)
+            MNIST(os.getcwd(), train=False, download=True)
+
+        def setup(self, stage):
+            transform = transforms.Compose([
+                transforms.ToTensor(),
+                transforms.Normalize((0.1307,), (0.3081,))
+            ])
+
+            if stage == 'fit':
+                mnist_train = MNIST(os.getcwd(), train=True, transform=transform)
+                self.mnist_train, self.mnist_val = random_split(mnist_train, [55000, 5000])
+            if stage == 'test':
+                mnist_test = MNIST(os.getcwd(), train=False, transform=transform)
+                self.mnist_test = mnist_test
 
         def train_dataloader(self):
-            # your train transforms
-            return DataLoader(YOUR_DATASET)
+            mnist_train = DataLoader(self.mnist_train, batch_size=self.batch_size)
+            return mnist_train
 
         def val_dataloader(self):
-            # your val transforms
-            return DataLoader(YOUR_DATASET)
+            mnist_val = DataLoader(self.mnist_val, batch_size=self.batch_size)
+            return mnist_val
 
         def test_dataloader(self):
-            # your test transforms
-            return DataLoader(YOUR_DATASET)
+            mnist_test = DataLoader(self.mnist_test, batch_size=self.batch_size)
+            return mnist_test
 
 And train like so:
 
 .. code-block:: python
 
-    dm = MyDataModule()
+    dm = MNISTDataModule()
     trainer.fit(model, dm)
 
 When doing distributed training, Datamodules have two optional arguments for granular control
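For context, here is a minimal end-to-end sketch of how the pieces in the patched docs fit together, assuming the ``LitModel`` and ``MNISTDataModule`` defined above are in scope (``max_epochs=1`` is only illustrative, not part of the patch):

.. code-block:: python

    import pytorch_lightning as pl

    # LitModel and MNISTDataModule are the classes from the docs above
    model = LitModel()
    dm = MNISTDataModule(batch_size=32)

    # fit() calls dm.prepare_data() and dm.setup('fit') before training
    trainer = pl.Trainer(max_epochs=1)
    trainer.fit(model, dm)

    # test() calls dm.setup('test') and evaluates on dm.test_dataloader()
    trainer.test(model, datamodule=dm)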