quick start docs changes (#3028)
* updated code example * added warning when changing monitor and using results obj * added warning when changing monitor and using results obj * added warning when changing monitor and using results obj * added warning when changing monitor and using results obj * added warning when changing monitor and using results obj * added warning when changing monitor and using results obj * added warning when changing monitor and using results obj * added warning when changing monitor and using results obj * added warning when changing monitor and using results obj * added warning when changing monitor and using results obj * added warning when changing monitor and using results obj * added warning when changing monitor and using results obj
This commit is contained in:
parent
ca18e11f6e
commit
5dfc7b157e
|
@ -15,7 +15,7 @@ code to work with Lightning.
|
|||
|
||||
.. raw:: html
|
||||
|
||||
<video width="100%" controls autoplay src="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/pl_module_vid.m4v"></video>
|
||||
<video width="100%" controls autoplay src="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/pl_quick_start_full.m4v"></video>
|
||||
|
||||
|
|
||||
|
||||
|
|
|
@ -6,6 +6,9 @@
|
|||
import torch
|
||||
from torch.nn import functional as F
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data import DataLoader
|
||||
import pytorch_lightning as pl
|
||||
from torch.utils.data import random_split
|
||||
|
||||
.. _quick-start:
|
||||
|
||||
|
@ -16,11 +19,11 @@ PyTorch Lightning is nothing more than organized PyTorch code.
|
|||
|
||||
Once you've organized it into a LightningModule, it automates most of the training for you.
|
||||
|
||||
To illustrate, here's the typical PyTorch project structure organized in a LightningModule.
|
||||
Here's a 2 minute conversion guide for PyTorch projects:
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<video width="100%" controls autoplay src="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/pl_module_vid.m4v"></video>
|
||||
<video width="100%" controls autoplay src="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/pl_quick_start_full.m4v"></video>
|
||||
|
||||
----------
|
||||
|
||||
|
@ -34,12 +37,16 @@ A lightningModule defines
|
|||
- Model + system architecture
|
||||
- Optimizer
|
||||
|
||||
.. testcode::
|
||||
:skipif: not TORCHVISION_AVAILABLE
|
||||
|
||||
.. code-block::
|
||||
|
||||
import os
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torchvision.datasets import MNIST
|
||||
from torchvision import transforms
|
||||
from torch.utils.data import DataLoader
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning.metrics.functional import accuracy
|
||||
from torch.utils.data import random_split
|
||||
|
||||
class LitModel(pl.LightningModule):
|
||||
|
||||
|
@ -74,7 +81,7 @@ well across any accelerator.
|
|||
|
||||
Here's an example of using the Trainer:
|
||||
|
||||
.. code-block:: python
|
||||
.. code-block::
|
||||
|
||||
# dataloader
|
||||
dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())
|
||||
|
@ -83,7 +90,7 @@ Here's an example of using the Trainer:
|
|||
# init model
|
||||
model = LitModel()
|
||||
|
||||
# most basic trainer, uses good defaults
|
||||
# most basic trainer, uses good defaults (auto-tensorboard, checkpoints, logs, and more)
|
||||
trainer = pl.Trainer()
|
||||
trainer.fit(model, train_loader)
|
||||
|
||||
|
@ -350,30 +357,49 @@ And the matching code:
|
|||
|
||||
|
|
||||
|
||||
.. code-block:: python
|
||||
.. code-block::
|
||||
|
||||
class MyDataModule(pl.DataModule):
|
||||
class MNISTDataModule(pl.LightningDataModule):
|
||||
|
||||
def __init__(self):
|
||||
...
|
||||
def __init__(self, batch_size=32):
|
||||
super().__init__()
|
||||
self.batch_size = batch_size
|
||||
|
||||
def prepare_data(self):
|
||||
# optional to support downloading only once when using multi-GPU or multi-TPU
|
||||
MNIST(os.getcwd(), train=True, download=True)
|
||||
MNIST(os.getcwd(), train=False, download=True)
|
||||
|
||||
def setup(self, stage):
|
||||
transform=transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.1307,), (0.3081,))
|
||||
])
|
||||
|
||||
if stage == 'fit':
|
||||
mnist_train = MNIST(os.getcwd(), train=True, transform=transform)
|
||||
self.mnist_train, self.mnist_val = random_split(mnist_train, [55000, 5000])
|
||||
if stage == 'test':
|
||||
mnist_test = MNIST(os.getcwd(), train=False, transform=transform)
|
||||
self.mnist_test = MNIST(os.getcwd(), train=False, download=True)
|
||||
|
||||
def train_dataloader(self):
|
||||
# your train transforms
|
||||
return DataLoader(YOUR_DATASET)
|
||||
mnist_train = DataLoader(self.mnist_train, batch_size=self.batch_size)
|
||||
return mnist_train
|
||||
|
||||
def val_dataloader(self):
|
||||
# your val transforms
|
||||
return DataLoader(YOUR_DATASET)
|
||||
mnist_val = DataLoader(self.mnist_val, batch_size=self.batch_size)
|
||||
return mnist_val
|
||||
|
||||
def test_dataloader(self):
|
||||
# your test transforms
|
||||
return DataLoader(YOUR_DATASET)
|
||||
mnist_test = DataLoader(mnist_test, batch_size=self.batch_size)
|
||||
return mnist_test
|
||||
|
||||
And train like so:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
dm = MyDataModule()
|
||||
dm = MNISTDataModule()
|
||||
trainer.fit(model, dm)
|
||||
|
||||
When doing distributed training, Datamodules have two optional arguments for granular control
|
||||
|
|
Loading…
Reference in New Issue