updated test (#1073)
parent ff1f8ef400
commit 9f140b7698
@@ -3,7 +3,7 @@ Child Modules
 Research projects tend to test different approaches to the same dataset.
 This is very easy to do in Lightning with inheritance.
 
-For example, imaging we now want to train an Autoencoder to use as a feature extractor for MNIST images.
+For example, imagine we now want to train an Autoencoder to use as a feature extractor for MNIST images.
 Recall that `LitMNIST` already defines all the dataloading etc... The only things
 that change in the `Autoencoder` model are the init, forward, training, validation and test step.
 
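For context while reading the diff: a minimal sketch of the inheritance pattern that passage describes, assuming the `LitMNIST` module defined earlier in the guide. The autoencoder layers, sizes, and loss below are illustrative, not taken from the commit.

.. code-block:: python

    import torch
    from torch import nn
    from torch.nn import functional as F

    class Autoencoder(LitMNIST):
        # dataloading (prepare_data, train_dataloader, ...) is inherited
        # from LitMNIST; only the model and the step logic change

        def __init__(self):
            super().__init__()
            self.encoder = nn.Linear(28 * 28, 64)   # illustrative sizes
            self.decoder = nn.Linear(64, 28 * 28)

        def forward(self, x):
            return self.decoder(torch.relu(self.encoder(x)))

        def training_step(self, batch, batch_idx):
            x, _ = batch                  # labels are unused for reconstruction
            x = x.view(x.size(0), -1)
            x_hat = self(x)
            loss = F.mse_loss(x_hat, x)
            return {'loss': loss}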
@@ -212,6 +212,50 @@ Notice the code is exactly the same, except now the training dataloading has been
 under the `train_dataloader` method. This is great because if you run into a project that uses Lightning and want
 to figure out how they prepare their training data you can just look in the `train_dataloader` method.
 
+Usually though, we want to separate the things that write to disk in data-processing from
+things like transforms which happen in memory.
+
+.. code-block:: python
+
+    class LitMNIST(pl.LightningModule):
+
+        def prepare_data(self):
+            # download only
+            MNIST(os.getcwd(), train=True, download=True)
+
+        def train_dataloader(self):
+            # no download, just transform
+            transform = transforms.Compose([transforms.ToTensor(),
+                                            transforms.Normalize((0.1307,), (0.3081,))])
+            mnist_train = MNIST(os.getcwd(), train=True, download=False,
+                                transform=transform)
+            return DataLoader(mnist_train, batch_size=64)
+
+Doing it in the `prepare_data` method ensures that when you have
+multiple GPUs you won't overwrite the data. This is a contrived example
+but it gets more complicated with things like NLP or Imagenet.
+
+In general fill these methods with the following:
+
+.. code-block:: python
+
+    class LitMNIST(pl.LightningModule):
+
+        def prepare_data(self):
+            # stuff here is done once at the very beginning of training
+            # before any distributed training starts
+
+            # download stuff
+            # save to disk
+            # etc...
+
+        def train_dataloader(self):
+            # data transforms
+            # dataset creation
+            # return a DataLoader
+
 
 Optimizer
 ^^^^^^^^^
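To make the download-once guarantee concrete, a hedged usage sketch, assuming the Trainer API of this Lightning version: `prepare_data` runs a single time before the distributed processes start, while `train_dataloader` runs in every process and only does in-memory work. The GPU count and backend flag are illustrative.

.. code-block:: python

    from pytorch_lightning import Trainer

    model = LitMNIST()

    # prepare_data() is called once and downloads MNIST to disk;
    # each DDP process then calls train_dataloader(), which only
    # builds transforms and a DataLoader, so nothing is written twice
    trainer = Trainer(gpus=2, distributed_backend='ddp')  # illustrative settings
    trainer.fit(model)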
@@ -606,11 +650,11 @@ metrics we care about, generate samples or add more to our logs.
     loss = loss(y_hat, x)                  # validation_step
     outputs.append({'val_loss': loss})     # validation_step
 
-    full_loss = outputs.mean()             # validation_end
+    full_loss = outputs.mean()             # validation_epoch_end
 
 Since the `validation_step` processes a single batch,
-in Lightning we also have a `validation_end` method which allows you to compute
-statistics on the full dataset and not just the batch.
+in Lightning we also have a `validation_epoch_end` method which allows you to compute
+statistics on the full dataset after an epoch of validation data and not just the batch.
 
 In addition, we define a `val_dataloader` method which tells the trainer what data to use for validation.
 Notice we split the train split of MNIST into train, validation. We also have to make sure to do the
@@ -640,7 +684,7 @@ sample split in the `train_dataloader` method.
     return mnist_val
 
 Again, we've just organized the regular PyTorch code into two steps, the `validation_step` method which
-operates on a single batch and the `validation_end` method to compute statistics on all batches.
+operates on a single batch and the `validation_epoch_end` method to compute statistics on all batches.
 
 If you have these methods defined, Lightning will call them automatically. Now we can train
 while checking the validation set.
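Putting the renamed hook in context: a minimal sketch of the two validation steps working together, assuming the post-rename method names this commit adopts; the model call and metric are illustrative.

.. code-block:: python

    import torch
    from torch.nn import functional as F
    import pytorch_lightning as pl

    class LitMNIST(pl.LightningModule):

        def validation_step(self, batch, batch_idx):
            # called once per batch
            x, y = batch
            logits = self(x)
            return {'val_loss': F.cross_entropy(logits, y)}

        def validation_epoch_end(self, outputs):
            # called once per epoch with every dict returned above
            avg_loss = torch.stack([o['val_loss'] for o in outputs]).mean()
            return {'val_loss': avg_loss}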
@@ -669,7 +713,7 @@ how it will generalize in the "real world." For this, we use a held-out split of
 Just like the validation loop, we define exactly the same steps for testing:
 
 - test_step
-- test_end
+- test_epoch_end
 - test_dataloader
 
 .. code-block:: python
@@ -707,6 +751,17 @@ Once you train your model simply call `.test()`.
     # run test set
     trainer.test()
 
+.. rst-class:: sphx-glr-script-out
+
+ Out:
+
+ .. code-block:: none
+
+    --------------------------------------------------------------
+    TEST RESULTS
+    {'test_loss': tensor(1.1703, device='cuda:0')}
+    --------------------------------------------------------------
+
 You can also run the test from a saved lightning model
 
 .. code-block:: python
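The code block after "You can also run the test from a saved lightning model" is not shown in this diff. As an illustrative sketch only, assuming the standard `load_from_checkpoint` API of this Lightning version (the checkpoint path is a placeholder):

.. code-block:: python

    from pytorch_lightning import Trainer

    # CHECKPOINT_PATH is a placeholder for a checkpoint saved by a Trainer
    model = LitMNIST.load_from_checkpoint(CHECKPOINT_PATH)
    trainer = Trainer()
    trainer.test(model)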
@@ -881,6 +936,7 @@ you could do your own:
 Every single part of training is configurable this way.
 For a full list look at `lightningModule <lightning-module.rst>`_.
 
+---------
 
 Callbacks
 ---------