quick start docs changes (#3028)

* updated code example

* added warning when changing monitor and using results obj

* added warning when changing monitor and using results obj

* added warning when changing monitor and using results obj

* added warning when changing monitor and using results obj

* added warning when changing monitor and using results obj

* added warning when changing monitor and using results obj

* added warning when changing monitor and using results obj

* added warning when changing monitor and using results obj

* added warning when changing monitor and using results obj

* added warning when changing monitor and using results obj

* added warning when changing monitor and using results obj

* added warning when changing monitor and using results obj
This commit is contained in:
William Falcon 2020-08-17 23:17:51 -04:00 committed by GitHub
parent ca18e11f6e
commit 5dfc7b157e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 46 additions and 20 deletions

View File

@ -15,7 +15,7 @@ code to work with Lightning.
.. raw:: html .. raw:: html
<video width="100%" controls autoplay src="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/pl_module_vid.m4v"></video> <video width="100%" controls autoplay src="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/pl_quick_start_full.m4v"></video>
| |

View File

@ -6,6 +6,9 @@
import torch import torch
from torch.nn import functional as F from torch.nn import functional as F
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from torch.utils.data import random_split
.. _quick-start: .. _quick-start:
@ -16,11 +19,11 @@ PyTorch Lightning is nothing more than organized PyTorch code.
Once you've organized it into a LightningModule, it automates most of the training for you. Once you've organized it into a LightningModule, it automates most of the training for you.
To illustrate, here's the typical PyTorch project structure organized in a LightningModule. Here's a 2 minute conversion guide for PyTorch projects:
.. raw:: html .. raw:: html
<video width="100%" controls autoplay src="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/pl_module_vid.m4v"></video> <video width="100%" controls autoplay src="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/pl_quick_start_full.m4v"></video>
---------- ----------
@ -34,12 +37,16 @@ A lightningModule defines
- Model + system architecture - Model + system architecture
- Optimizer - Optimizer
.. testcode:: .. code-block::
:skipif: not TORCHVISION_AVAILABLE
import os
import torch
import torch.nn.functional as F
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader
import pytorch_lightning as pl import pytorch_lightning as pl
from pytorch_lightning.metrics.functional import accuracy from torch.utils.data import random_split
class LitModel(pl.LightningModule): class LitModel(pl.LightningModule):
@ -74,7 +81,7 @@ well across any accelerator.
Here's an example of using the Trainer: Here's an example of using the Trainer:
.. code-block:: python .. code-block::
# dataloader # dataloader
dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor()) dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())
@ -83,7 +90,7 @@ Here's an example of using the Trainer:
# init model # init model
model = LitModel() model = LitModel()
# most basic trainer, uses good defaults # most basic trainer, uses good defaults (auto-tensorboard, checkpoints, logs, and more)
trainer = pl.Trainer() trainer = pl.Trainer()
trainer.fit(model, train_loader) trainer.fit(model, train_loader)
@ -350,30 +357,49 @@ And the matching code:
| |
.. code-block:: python .. code-block::
class MyDataModule(pl.DataModule): class MNISTDataModule(pl.LightningDataModule):
def __init__(self): def __init__(self, batch_size=32):
... super().__init__()
self.batch_size = batch_size
def prepare_data(self):
# optional to support downloading only once when using multi-GPU or multi-TPU
MNIST(os.getcwd(), train=True, download=True)
MNIST(os.getcwd(), train=False, download=True)
def setup(self, stage):
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
if stage == 'fit':
mnist_train = MNIST(os.getcwd(), train=True, transform=transform)
self.mnist_train, self.mnist_val = random_split(mnist_train, [55000, 5000])
if stage == 'test':
mnist_test = MNIST(os.getcwd(), train=False, transform=transform)
self.mnist_test = MNIST(os.getcwd(), train=False, download=True)
def train_dataloader(self): def train_dataloader(self):
# your train transforms mnist_train = DataLoader(self.mnist_train, batch_size=self.batch_size)
return DataLoader(YOUR_DATASET) return mnist_train
def val_dataloader(self): def val_dataloader(self):
# your val transforms mnist_val = DataLoader(self.mnist_val, batch_size=self.batch_size)
return DataLoader(YOUR_DATASET) return mnist_val
def test_dataloader(self): def test_dataloader(self):
# your test transforms mnist_test = DataLoader(mnist_test, batch_size=self.batch_size)
return DataLoader(YOUR_DATASET) return mnist_test
And train like so: And train like so:
.. code-block:: python .. code-block:: python
dm = MyDataModule() dm = MNISTDataModule()
trainer.fit(model, dm) trainer.fit(model, dm)
When doing distributed training, Datamodules have two optional arguments for granular control When doing distributed training, Datamodules have two optional arguments for granular control