updated multiple val dataset docs
This commit is contained in:
parent
0d31b9a229
commit
acc16565c5
|
@ -52,7 +52,7 @@ class CoolModel(pl.LightningModule):
|
||||||
y_hat = self.forward(x)
|
y_hat = self.forward(x)
|
||||||
return {'loss': F.cross_entropy(y_hat, y)(y_hat, y)}
|
return {'loss': F.cross_entropy(y_hat, y)(y_hat, y)}
|
||||||
|
|
||||||
def validation_step(self, batch, batch_nb, dataloader_i):
|
def validation_step(self, batch, batch_nb):
|
||||||
# OPTIONAL
|
# OPTIONAL
|
||||||
x, y = batch
|
x, y = batch
|
||||||
y_hat = self.forward(x)
|
y_hat = self.forward(x)
|
||||||
|
@ -215,7 +215,7 @@ This is most likely the same as your training_step. But unlike training step, th
|
||||||
|---|---|
|
|---|---|
|
||||||
| data_batch | The output of your dataloader. A tensor, tuple or list |
|
| data_batch | The output of your dataloader. A tensor, tuple or list |
|
||||||
| batch_nb | Integer displaying which batch this is |
|
| batch_nb | Integer displaying which batch this is |
|
||||||
| dataloader_i | Integer displaying which dataloader this is |
|
| dataloader_i | Integer displaying which dataloader this is (only if multiple val datasets used) |
|
||||||
|
|
||||||
**Return**
|
**Return**
|
||||||
|
|
||||||
|
@ -226,6 +226,7 @@ This is most likely the same as your training_step. But unlike training step, th
|
||||||
**Example**
|
**Example**
|
||||||
|
|
||||||
``` {.python}
|
``` {.python}
|
||||||
|
# CASE 1: A single validation dataset
|
||||||
def validation_step(self, data_batch, batch_nb):
|
def validation_step(self, data_batch, batch_nb):
|
||||||
x, y, z = data_batch
|
x, y, z = data_batch
|
||||||
|
|
||||||
|
@ -246,7 +247,17 @@ def validation_step(self, data_batch, batch_nb):
|
||||||
|
|
||||||
# return an optional dict
|
# return an optional dict
|
||||||
return output
|
return output
|
||||||
```
|
```
|
||||||
|
|
||||||
|
If you pass in multiple validation datasets, validation_step will have an additional argument.
|
||||||
|
|
||||||
|
```python
|
||||||
|
# CASE 2: multiple validation datasets
|
||||||
|
def validation_step(self, data_batch, batch_nb, dataset_idx):
|
||||||
|
# dataset_idx tells you which dataset this is.
|
||||||
|
```
|
||||||
|
|
||||||
|
The ```dataset_idx``` corresponds to the order of datasets returned in ```val_dataloader```.
|
||||||
|
|
||||||
---
|
---
|
||||||
### validation_end
|
### validation_end
|
||||||
|
@ -371,6 +382,9 @@ def val_dataloader(self):
|
||||||
return [loader_a, loader_b, ..., loader_n]
|
return [loader_a, loader_b, ..., loader_n]
|
||||||
```
|
```
|
||||||
|
|
||||||
|
In the case where you return multiple val_dataloaders, the validation_step will have an arguement ```dataset_idx```
|
||||||
|
which matches the order here.
|
||||||
|
|
||||||
---
|
---
|
||||||
### test_dataloader
|
### test_dataloader
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue