updated test (#1073)

William Falcon 2020-03-06 12:12:39 -05:00 committed by J. Borovec
parent ff1f8ef400
commit 9f140b7698
2 changed files with 62 additions and 6 deletions

View File

@ -3,7 +3,7 @@ Child Modules
Research projects tend to test different approaches to the same dataset.
This is very easy to do in Lightning with inheritance.
For example, imaging we now want to train an Autoencoder to use as a feature extractor for MNIST images.
For example, imagine we now want to train an Autoencoder to use as a feature extractor for MNIST images.
Recall that `LitMNIST` already defines all the dataloading, etc. The only things
that change in the `Autoencoder` model are the init, forward, training, validation and test steps.
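To make that concrete, here is a minimal sketch of what such a child module could look like, assuming the `LitMNIST`
class and imports from earlier in the guide; the encoder/decoder sizes, the MSE reconstruction loss and the returned
dict are illustrative placeholders, not the exact implementation used in the docs:

.. code-block:: python

    import torch
    from torch.nn import functional as F


    class Autoencoder(LitMNIST):

        def __init__(self):
            super().__init__()
            # placeholder architecture: a single linear encoder/decoder pair
            self.encoder = torch.nn.Linear(28 * 28, 64)
            self.decoder = torch.nn.Linear(64, 28 * 28)

        def forward(self, x):
            x = x.view(x.size(0), -1)
            return self.decoder(self.encoder(x))

        def training_step(self, batch, batch_idx):
            x, _ = batch  # labels are unused for reconstruction
            x_hat = self(x)
            loss = F.mse_loss(x_hat, x.view(x.size(0), -1))
            return {'loss': loss}

The dataloading is inherited from `LitMNIST` unchanged; the validation and test steps would be overridden in the same way.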

View File

@ -212,6 +212,50 @@ Notice the code is exactly the same, except now the training dataloading has bee
moved under the `train_dataloader` method. This is great because if you run into a project that uses Lightning and want
to figure out how they prepare their training data, you can just look in the `train_dataloader` method.

Usually, though, we want to separate the things that write to disk (data processing) from the
things that happen in memory, such as transforms.

.. code-block:: python

    class LitMNIST(pl.LightningModule):

        def prepare_data(self):
            # download only
            MNIST(os.getcwd(), train=True, download=True)

        def train_dataloader(self):
            # no download, just transform
            transform = transforms.Compose([transforms.ToTensor(),
                                            transforms.Normalize((0.1307,), (0.3081,))])
            mnist_train = MNIST(os.getcwd(), train=True, download=False,
                                transform=transform)
            return DataLoader(mnist_train, batch_size=64)

Doing the download in the `prepare_data` method ensures that when you have
multiple GPUs you won't overwrite the data. This is a contrived example,
but it gets more complicated with things like NLP or ImageNet.

In general, fill these methods with the following:

.. code-block:: python

    class LitMNIST(pl.LightningModule):

        def prepare_data(self):
            # stuff here is done once at the very beginning of training
            # before any distributed training starts

            # download stuff
            # save to disk
            # etc...
            ...

        def train_dataloader(self):
            # data transforms
            # dataset creation
            # return a DataLoader
            ...

Optimizer
^^^^^^^^^
@ -606,11 +650,11 @@ metrics we care about, generate samples or add more to our logs.
loss = loss(y_hat, x) # validation_step
outputs.append({'val_loss': loss}) # validation_step
full_loss = outputs.mean() # validation_end
full_loss = outputs.mean() # validation_epoch_end
Since the `validation_step` processes a single batch,
in Lightning we also have a `validation_end` method which allows you to compute
statistics on the full dataset and not just the batch.
in Lightning we also have a `validation_epoch_end` method which allows you to compute
statistics on the full dataset after an epoch of validation data and not just the batch.
In addition, we define a `val_dataloader` method which tells the trainer what data to use for validation.
Notice we split the train split of MNIST into train and validation. We also have to make sure to do the
@ -640,7 +684,7 @@ sample split in the `train_dataloader` method.
return mnist_val
Again, we've just organized the regular PyTorch code into two steps, the `validation_step` method which
operates on a single batch and the `validation_end` method to compute statistics on all batches.
operates on a single batch and the `validation_epoch_end` method to compute statistics on all batches.
If you have these methods defined, Lightning will call them automatically. Now we can train
while checking the validation set.
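
Putting that together, a minimal sketch of the three validation hooks could look like the following, assuming the
forward, imports and training hooks defined earlier (plus `random_split` from `torch.utils.data`); the cross-entropy
loss, the 55,000/5,000 split and the batch size are illustrative assumptions:

.. code-block:: python

    class LitMNIST(pl.LightningModule):

        def validation_step(self, batch, batch_idx):
            # runs on a single batch of the validation set
            x, y = batch
            logits = self(x)
            loss = F.cross_entropy(logits, y)
            return {'val_loss': loss}

        def validation_epoch_end(self, outputs):
            # runs once per epoch, over the outputs of every validation_step
            avg_loss = torch.stack([o['val_loss'] for o in outputs]).mean()
            return {'val_loss': avg_loss}

        def val_dataloader(self):
            # tells the trainer which data to validate on
            transform = transforms.Compose([transforms.ToTensor(),
                                            transforms.Normalize((0.1307,), (0.3081,))])
            mnist_train = MNIST(os.getcwd(), train=True, download=False, transform=transform)
            _, mnist_val = random_split(mnist_train, [55000, 5000])
            return DataLoader(mnist_val, batch_size=64)
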
@ -669,7 +713,7 @@ how it will generalize in the "real world." For this, we use a held-out split of
Just like the validation loop, we define exactly the same steps for testing:
- test_step
- test_end
- test_epoch_end
- test_dataloader
.. code-block:: python
@ -707,6 +751,17 @@ Once you train your model simply call `.test()`.
    # run test set
    trainer.test()

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    --------------------------------------------------------------
    TEST RESULTS
    {'test_loss': tensor(1.1703, device='cuda:0')}
    --------------------------------------------------------------

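For reference, the test hooks follow exactly the same pattern as the validation hooks; a minimal sketch, under the
same illustrative assumptions as above (and assuming the test split was already downloaded in `prepare_data`):

.. code-block:: python

    class LitMNIST(pl.LightningModule):

        def test_step(self, batch, batch_idx):
            x, y = batch
            logits = self(x)
            loss = F.cross_entropy(logits, y)
            return {'test_loss': loss}

        def test_epoch_end(self, outputs):
            avg_loss = torch.stack([o['test_loss'] for o in outputs]).mean()
            return {'test_loss': avg_loss}

        def test_dataloader(self):
            # the held-out split, never touched during training or validation
            transform = transforms.Compose([transforms.ToTensor(),
                                            transforms.Normalize((0.1307,), (0.3081,))])
            mnist_test = MNIST(os.getcwd(), train=False, download=False, transform=transform)
            return DataLoader(mnist_test, batch_size=64)
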
You can also run the test from a saved Lightning model.

.. code-block:: python
@ -881,6 +936,7 @@ you could do your own:
Every single part of training is configurable this way.
For a full list, look at the `LightningModule <lightning-module.rst>`_ docs.

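As a quick illustrative sketch (the optimizer choice and learning rate are placeholders, not the docs' defaults),
swapping in a different optimizer only requires overriding `configure_optimizers`:

.. code-block:: python

    from torch.optim import SGD


    class LitMNIST(pl.LightningModule):

        def configure_optimizers(self):
            # return whichever optimizer the trainer should use
            return SGD(self.parameters(), lr=0.02)
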
---------
Callbacks
---------