diff --git a/docs/source/introduction_guide.rst b/docs/source/introduction_guide.rst
index e6d8b33ce8..517fb10409 100644
--- a/docs/source/introduction_guide.rst
+++ b/docs/source/introduction_guide.rst
@@ -1,6 +1,7 @@
 .. testsetup:: *
 
     from pytorch_lightning.core.lightning import LightningModule
+    from pytorch_lightning.core.datamodule import LightningDataModule
     from pytorch_lightning.trainer.trainer import Trainer
 
 .. _introduction-guide:
@@ -259,9 +260,9 @@ In this case, it's better to group the full definition of a dataset into a `Data
 - Val dataloader(s)
 - Test dataloader(s)
 
-.. code-block:: python
+.. testcode::
 
-    class MyDataModule(pl.DataModule):
+    class MyDataModule(LightningDataModule):
 
         def __init__(self):
             super().__init__()
diff --git a/docs/source/lightning-module.rst b/docs/source/lightning-module.rst
index c6bf4453ba..6be28af555 100644
--- a/docs/source/lightning-module.rst
+++ b/docs/source/lightning-module.rst
@@ -51,7 +51,7 @@ Notice a few things.
 
         # or to init a new tensor
         new_x = torch.Tensor(2, 3)
-        new_x = new_x.type_as(x.type())
+        new_x = new_x.type_as(x)
 
 5. There are no samplers for distributed, Lightning also does this for you.
 
diff --git a/docs/source/new-project.rst b/docs/source/new-project.rst
index 68b4f54afa..2ce86bd0cd 100644
--- a/docs/source/new-project.rst
+++ b/docs/source/new-project.rst
@@ -1,6 +1,7 @@
 .. testsetup:: *
 
     from pytorch_lightning.core.lightning import LightningModule
+    from pytorch_lightning.core.datamodule import LightningDataModule
     from pytorch_lightning.trainer.trainer import Trainer
     import os
     import torch
@@ -357,9 +358,9 @@ And the matching code:
 
 |
 
-.. code-block::
+.. testcode::
 
-    class MNISTDataModule(pl.LightningDataModule):
+    class MNISTDataModule(LightningDataModule):
 
         def __init__(self, batch_size=32):
             super().__init__()
@@ -407,7 +408,7 @@ over download/prepare/splitting data
 
 .. code-block:: python
 
-    class MyDataModule(pl.DataModule):
+    class MyDataModule(LightningDataModule):
 
         def prepare_data(self):
             # called only on 1 GPU
@@ -415,12 +416,12 @@ over download/prepare/splitting data
             tokenize()
             etc()
 
-        def setup(self):
+        def setup(self, stage=None):
             # called on every GPU (assigning state is OK)
             self.train = ...
             self.val = ...
 
-        def train_dataloader(self):
+        def train_dataloader(self):
             # do more...
             return self.train
 
@@ -432,7 +433,7 @@ First, define the information that you might need.
 
 .. code-block:: python
 
-    class MyDataModule(pl.DataModule):
+    class MyDataModule(LightningDataModule):
 
         def __init__(self):
             super().__init__()
@@ -444,7 +445,7 @@ First, define the information that you might need.
             tokenize()
             build_vocab()
 
-        def setup(self):
+        def setup(self, stage=None):
             vocab = load_vocab
             self.vocab_size = len(vocab)
 
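The hunks above document the ``LightningDataModule`` lifecycle: ``prepare_data`` runs on a single process (download/tokenize, no state assignment), ``setup(stage)`` runs on every GPU/process (assigning state is OK), and the ``*_dataloader`` hooks return the loaders. A minimal runnable sketch of that lifecycle follows for reference; the ``RandomDataModule`` name and its random-tensor dataset are illustrative assumptions, not part of this patch:

.. code-block:: python

    import torch
    from torch.utils.data import DataLoader, TensorDataset, random_split

    from pytorch_lightning.core.datamodule import LightningDataModule


    class RandomDataModule(LightningDataModule):
        # Hypothetical example, not from the patched docs: a DataModule over
        # 100 random samples, just to show where each hook runs.

        def __init__(self, batch_size=32):
            super().__init__()
            self.batch_size = batch_size

        def prepare_data(self):
            # Called on exactly one process: download, tokenize, write to disk.
            # Do not assign state here; other processes won't see it.
            pass

        def setup(self, stage=None):
            # Called on every GPU/process, so assigning state is safe here.
            # The Trainer passes `stage` so you can set up only what a given
            # phase (fit vs. test) actually needs.
            dataset = TensorDataset(torch.randn(100, 3), torch.randint(0, 2, (100,)))
            self.train_set, self.val_set = random_split(dataset, [80, 20])

        def train_dataloader(self):
            return DataLoader(self.train_set, batch_size=self.batch_size)

        def val_dataloader(self):
            return DataLoader(self.val_set, batch_size=self.batch_size)

Handed to the Trainer (``trainer.fit(model, datamodule=RandomDataModule())``), these hooks are invoked automatically in order: ``prepare_data``, then ``setup``, then the dataloader hooks.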