updated multiple val dataset docs

William Falcon 2019-08-13 11:43:21 -04:00
parent 0d31b9a229
commit acc16565c5
1 changed file with 17 additions and 3 deletions


@@ -52,7 +52,7 @@ class CoolModel(pl.LightningModule):
         y_hat = self.forward(x)
         return {'loss': F.cross_entropy(y_hat, y)}
-    def validation_step(self, batch, batch_nb, dataloader_i):
+    def validation_step(self, batch, batch_nb):
         # OPTIONAL
         x, y = batch
         y_hat = self.forward(x)
@@ -215,7 +215,7 @@ This is most likely the same as your training_step. But unlike training step, th
 |---|---|
 | data_batch | The output of your dataloader. A tensor, tuple or list |
 | batch_nb | Integer displaying which batch this is |
-| dataloader_i | Integer displaying which dataloader this is |
+| dataloader_i | Integer displaying which dataloader this is (only if multiple val datasets used) |
 **Return**
@@ -226,6 +226,7 @@ This is most likely the same as your training_step. But unlike training step, th
 **Example**
 ``` {.python}
+# CASE 1: A single validation dataset
 def validation_step(self, data_batch, batch_nb):
     x, y, z = data_batch
@@ -246,7 +247,17 @@ def validation_step(self, data_batch, batch_nb):
     # return an optional dict
     return output
 ```
+
+If you pass in multiple validation datasets, validation_step will have an additional argument.
+
+```python
+# CASE 2: multiple validation datasets
+def validation_step(self, data_batch, batch_nb, dataset_idx):
+    # dataset_idx tells you which dataset this is.
+```
+
+The ```dataset_idx``` corresponds to the order of datasets returned in ```val_dataloader```.
 ---
 ### validation_end
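The CASE 2 snippet in the hunk above only shows the new signature. As a rough sketch (not part of this commit) of a validation_step that actually uses the extra index: the model name, self.forward, and the metric key are placeholder assumptions, and the default value for dataset_idx is only there so the same method also works with a single val dataset.

```python
import pytorch_lightning as pl
import torch.nn.functional as F


class CoolModel(pl.LightningModule):
    # Sketch only: forward(), training_step, and the dataloaders are assumed
    # to be defined elsewhere in the module, as in the minimal example above.
    def validation_step(self, data_batch, batch_nb, dataset_idx=0):
        # dataset_idx identifies which val dataloader produced this batch;
        # it is only passed in when val_dataloader returns multiple loaders,
        # hence the assumed default of 0 for the single-dataset case.
        x, y = data_batch
        y_hat = self.forward(x)
        return {'val_loss_%d' % dataset_idx: F.cross_entropy(y_hat, y)}
```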
@@ -371,6 +382,9 @@ def val_dataloader(self):
     return [loader_a, loader_b, ..., loader_n]
 ```
+
+In the case where you return multiple val_dataloaders, the validation_step will have an argument ```dataset_idx```
+which matches the order here.
 ---
 ### test_dataloader
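To tie the last two hunks together, a minimal sketch of a val_dataloader that returns two loaders whose list order is what dataset_idx refers to in validation_step (0 for the first loader, 1 for the second); the dummy tensor datasets and batch size are illustrative assumptions, not part of the commit.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def val_dataloader(self):
    # Dummy datasets purely for illustration; any torch Dataset works here.
    dataset_a = TensorDataset(torch.randn(64, 10), torch.randint(0, 2, (64,)))
    dataset_b = TensorDataset(torch.randn(64, 10), torch.randint(0, 2, (64,)))
    loader_a = DataLoader(dataset_a, batch_size=32)
    loader_b = DataLoader(dataset_b, batch_size=32)
    # The list order defines dataset_idx: batches from loader_a arrive with
    # dataset_idx=0, batches from loader_b with dataset_idx=1.
    return [loader_a, loader_b]
```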