Docs clean-up (#2234)

* update docs

* update docs

* update docs

* update docs

* update docs

* update docs
William Falcon 2020-06-18 08:29:18 -04:00 committed by GitHub
parent a2d3ee80ad
commit 79e1426161
8 changed files with 30 additions and 18 deletions

View File

@@ -278,6 +278,9 @@ Doing it in the `prepare_data` method ensures that when you have
multiple GPUs you won't overwrite the data. This is a contrived example
but it gets more complicated with things like NLP or Imagenet.
+`prepare_data` is called on the `LOCAL_RANK=0` GPU of each node. If your nodes share a file system,
+set `Trainer(prepare_data_per_node=False)` and it will be called from node=0, gpu=0 only.
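For example, on a cluster whose nodes mount a shared filesystem, the flag goes straight to the Trainer; a minimal sketch (the model class and device counts are placeholders):

.. code-block:: python

    from pytorch_lightning import Trainer

    model = LitMNIST()  # placeholder module implementing prepare_data/setup
    # shared filesystem: download once globally rather than once per node
    trainer = Trainer(gpus=8, num_nodes=4, prepare_data_per_node=False)
    trainer.fit(model)  # prepare_data() then runs on node=0, gpu=0 only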
In general fill these methods with the following:
.. testcode::
@@ -535,16 +538,21 @@ will cause all sorts of issues.
To solve this problem, move the download code to the `prepare_data` method in the LightningModule.
In this method we do all the preparation that needs to happen only once (instead of on every GPU).
`prepare_data` can be called in two ways: once per node, or only on the root node (`Trainer(prepare_data_per_node=False)`).

.. testcode::

    class LitMNIST(LightningModule):

        def prepare_data(self):
+           # download only
+           MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())
+           MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor())
+
+       def setup(self, stage):
            # transform
            transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

-           # download
-           mnist_train = MNIST(os.getcwd(), train=True, download=True, transform=transform)
-           mnist_test = MNIST(os.getcwd(), train=False, download=True, transform=transform)
+           mnist_train = MNIST(os.getcwd(), train=True, download=False, transform=transform)
+           mnist_test = MNIST(os.getcwd(), train=False, download=False, transform=transform)

            # train/val split
            mnist_train, mnist_val = random_split(mnist_train, [55000, 5000])
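Because the hooks are plain methods, they can be driven by hand for a quick sanity check; a minimal sketch of the order the Trainer follows (simplified, not the literal internals):

.. code-block:: python

    model = LitMNIST()
    model.prepare_data()   # download once (per node, or globally with prepare_data_per_node=False)
    model.setup('fit')     # build the transform and the train/val split on every process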

View File

@@ -307,12 +307,11 @@ class TransferLearningModel(pl.LightningModule):

    def prepare_data(self):
        """Download images and prepare image datasets."""
        # 1. Download the images
        download_and_extract_archive(url=DATA_URL,
                                     download_root=self.dl_path,
                                     remove_finished=True)

+   def setup(self, stage: str):
        data_path = Path(self.dl_path).joinpath('cats_and_dogs_filtered')

        # 2. Load the data + preprocessing & data augmentation
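The loading step that follows is truncated in this hunk; a sketch of how `setup` might continue, assuming torchvision's `ImageFolder` and a `train` subfolder (both are assumptions, not part of this diff):

.. code-block:: python

    from torchvision import transforms
    from torchvision.datasets import ImageFolder

    def setup(self, stage: str):
        data_path = Path(self.dl_path).joinpath('cats_and_dogs_filtered')
        # hypothetical continuation: build the training dataset from the extracted archive
        train_transform = transforms.Compose([transforms.Resize((224, 224)),
                                              transforms.ToTensor()])
        self.train_dataset = ImageFolder(root=data_path.joinpath('train'),
                                         transform=train_transform)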

View File

@@ -141,10 +141,13 @@ class LightningTemplateModel(LightningModule):
        return [optimizer], [scheduler]

    def prepare_data(self):
-       transform = transforms.Compose([transforms.ToTensor(),
-                                       transforms.Normalize((0.5,), (1.0,))])
-       self.mnist_train = MNIST(self.data_root, train=True, download=True, transform=transform)
-       self.mnist_test = MNIST(self.data_root, train=False, download=True, transform=transform)
+       MNIST(self.data_root, train=True, download=True, transform=transforms.ToTensor())
+       MNIST(self.data_root, train=False, download=True, transform=transforms.ToTensor())
+
+   def setup(self, stage):
+       transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])
+       self.mnist_train = MNIST(self.data_root, train=True, download=False, transform=transform)
+       self.mnist_test = MNIST(self.data_root, train=False, download=False, transform=transform)

    def train_dataloader(self):
        log.info('Training data loader called.')
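With the datasets assigned in `setup`, the dataloader hooks only need to wrap them; a sketch of the truncated body (the batch size is a placeholder):

.. code-block:: python

    from torch.utils.data import DataLoader

    def train_dataloader(self):
        log.info('Training data loader called.')
        return DataLoader(self.mnist_train, batch_size=32, shuffle=True)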

View File

@@ -267,11 +267,12 @@ allow for this:
>>> class LitModel(pl.LightningModule):
...     def prepare_data(self):
...         # download
-...         mnist_train = MNIST(os.getcwd(), train=True, download=True,
-...                             transform=transforms.ToTensor())
-...         mnist_test = MNIST(os.getcwd(), train=False, download=True,
-...                            transform=transforms.ToTensor())
+...         MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())
+...         MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor())
...
+...     def setup(self, stage):
+...         mnist_train = MNIST(os.getcwd(), train=True, download=False, transform=transforms.ToTensor())
+...         mnist_test = MNIST(os.getcwd(), train=False, download=False, transform=transforms.ToTensor())
...         # train/val split
...         mnist_train, mnist_val = random_split(mnist_train, [55000, 5000])
...
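Since `setup` receives the stage, work can be skipped when it is not needed; a sketch (the stage names follow the hook docs, everything else mirrors the example above):

.. code-block:: python

    def setup(self, stage):
        if stage == 'fit':
            mnist_train = MNIST(os.getcwd(), train=True, download=False, transform=transforms.ToTensor())
            self.mnist_train, self.mnist_val = random_split(mnist_train, [55000, 5000])
        if stage == 'test':
            self.mnist_test = MNIST(os.getcwd(), train=False, download=False, transform=transforms.ToTensor())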

View File

@@ -22,7 +22,7 @@ class ModelHooks(Module):
        Called at the beginning of fit and test.

        Args:
-           step: either 'fit' or 'test'
+           stage: either 'fit' or 'test'
        """

    def teardown(self, stage: str):
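A sketch of a module pairing the two hooks (`MyModule` and the temp-directory bodies are placeholders):

.. code-block:: python

    import shutil
    import tempfile

    class MyModule(LightningModule):
        def setup(self, stage: str):
            # stage is either 'fit' or 'test'
            self.scratch_dir = tempfile.mkdtemp()

        def teardown(self, stage: str):
            # called at the end of fit/test; undo whatever setup created
            shutil.rmtree(self.scratch_dir)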

View File

@@ -1290,7 +1290,8 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
    def prepare_data(self) -> None:
        """
        Use this to download and prepare data.

-       In distributed (GPU, TPU), this will only be called once.
+       In distributed (GPU, TPU), this will only be called once, on the local_rank=0 process of each node.
+       To call it only on the root node (node=0, gpu=0), use `Trainer(prepare_data_per_node=False)`.

        This is called before requesting the dataloaders:

        .. code-block:: python
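
            # sketch of the ordering described above (illustrative; the verbatim docstring body is truncated here)
            model.prepare_data()        # local_rank=0 of each node, or node=0/gpu=0 only
            model.setup('fit')          # every process
            model.train_dataloader()
            model.val_dataloader()
            model.test_dataloader()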

View File

@@ -104,7 +104,7 @@ class EvalModelTemplate(
        return nll

    def prepare_data(self):
-       _ = TrialMNIST(root=self.data_root, train=True, download=True)
+       TrialMNIST(root=self.data_root, train=True, download=True)

    @staticmethod
    def get_default_hparams(continue_training: bool = False, hpc_exp_number: int = 0) -> dict:

View File

@@ -69,7 +69,7 @@ def test_trainer_callback_system(tmpdir):
            self.on_test_start_called = False
            self.on_test_end_called = False

-       def setup(self, trainer, step: str):
+       def setup(self, trainer, stage: str):
            assert isinstance(trainer, Trainer)
            self.setup_called = True
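For reference, a user callback adopting the renamed argument; a minimal sketch (`PrintStageCallback` and its body are placeholders):

.. code-block:: python

    from pytorch_lightning.callbacks import Callback

    class PrintStageCallback(Callback):
        def setup(self, trainer, stage: str):
            # stage is 'fit' or 'test', matching the module-level hook
            print(f'setup called for stage={stage}')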