Support for multiple val_dataloaders (#97)

* Added support for multiple validation dataloaders

* Fix typo in README.md

* Update trainer.py

* Add support for multiple dataloaders

* Rename dataloader_index to dataloader_i

* Added warning to check val_dataloaders

Added a warning to check that all val_dataloaders use a DistributedSampler when ddp is enabled

* Updated DistributedSampler warning

* Fixed typo

* Added multiple val_dataloaders

* Multiple val_dataloader test

* Update lightning_module_template.py

Added dataloader_i to validation_step parameters

* Update trainer.py

* Reverted template changes

* Create multi_val_module.py

* Update no_val_end_module.py

* New MultiValModel

* Rename MultiValModel to MultiValTestModel

* Revert to LightningTestModel

* Update test_models.py

* Update trainer.py

* Update test_models.py

* multiple val_dataloaders in test template

* Fixed flake8 warnings

* Update trainer.py

* Fix flake errors

* Fixed Flake8 errors

* Update lm_test_module.py

keep this test model with a single dataset for val

* Update trainer.py

* Update trainer.py

* Update trainer.py

* Update trainer.py

* Update trainer.py

* Update test_models.py

* Update trainer.py

* Update trainer.py

* Update trainer.py

* Update trainer.py

* Update trainer.py

* Update trainer.py

* Update RequiredTrainerInterface.md

* Update RequiredTrainerInterface.md

* Update test_models.py

* Update trainer.py

don't need the else clause; val_dataloader is either a list or None because of get_dataloaders()

* Update trainer.py

fixed flake errors

* Update trainer.py
Sidhanth Holalkere 2019-08-12 15:23:11 -04:00 committed by William Falcon
parent 46e27e38aa
commit 511f7ecb9a
6 changed files with 196 additions and 107 deletions


@ -46,24 +46,25 @@ class CoolModel(pl.LightningModule):
def forward(self, x):
return torch.relu(self.l1(x.view(x.size(0), -1)))
def my_loss(self, y_hat, y):
return F.cross_entropy(y_hat, y)
def training_step(self, batch, batch_nb):
# REQUIRED
x, y = batch
y_hat = self.forward(x)
return {'loss': self.my_loss(y_hat, y)}
return {'loss': F.cross_entropy(y_hat, y)}
def validation_step(self, batch, batch_nb):
# OPTIONAL
x, y = batch
y_hat = self.forward(x)
return {'val_loss': self.my_loss(y_hat, y)}
return {'val_loss': F.cross_entropy(y_hat, y)}
def validation_end(self, outputs):
# OPTIONAL
avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
return {'avg_val_loss': avg_loss}
def configure_optimizers(self):
# REQUIRED
return [torch.optim.Adam(self.parameters(), lr=0.02)]
@pl.data_loader
@ -72,10 +73,13 @@ class CoolModel(pl.LightningModule):
@pl.data_loader
def val_dataloader(self):
# OPTIONAL
# can also return a list of val dataloaders
return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)
@pl.data_loader
def test_dataloader(self):
# OPTIONAL
return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)
```
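The comment above only hints at the new behavior, so here is a minimal illustrative sketch (not part of this commit) of a module that returns two validation dataloaders and receives dataloader_i in validation_step. The second loader and the per-loader metric keys are assumptions made for the example, not requirements of the interface.
``` {.python}
import os

import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST

import pytorch_lightning as pl


class MultiValCoolModel(pl.LightningModule):

    def __init__(self):
        super(MultiValCoolModel, self).__init__()
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_nb):
        x, y = batch
        return {'loss': F.cross_entropy(self.forward(x), y)}

    def validation_step(self, batch, batch_nb, dataloader_i):
        # dataloader_i says which of the val dataloaders produced this batch
        x, y = batch
        return {f'val_loss_{dataloader_i}': F.cross_entropy(self.forward(x), y)}

    def validation_end(self, outputs):
        # each output dict holds a single loss entry keyed by its dataloader
        avg_loss = torch.stack([list(out.values())[0] for out in outputs]).mean()
        return {'avg_val_loss': avg_loss}

    def configure_optimizers(self):
        return [torch.optim.Adam(self.parameters(), lr=0.02)]

    @pl.data_loader
    def tng_dataloader(self):
        mnist = MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())
        return DataLoader(mnist, batch_size=32)

    @pl.data_loader
    def val_dataloader(self):
        # returning a list activates the multiple-dataloader path added in this PR
        mnist = MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor())
        return [DataLoader(mnist, batch_size=32), DataLoader(mnist, batch_size=64)]
```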
---
@ -88,7 +92,7 @@ The LightningModule interface is on the right. Each method corresponds to a part
</a>
</p>
---
## Required Methods
### training_step
@ -134,15 +138,75 @@ def training_step(self, data_batch, batch_nb):
return output
```
---
### tng_dataloader
``` {.python}
@pl.data_loader
def tng_dataloader(self)
```
Called by lightning during the training loop. Make sure to use the @pl.data_loader decorator; this ensures the function is not called until the data are needed.
##### Return
PyTorch DataLoader
**Example**
``` {.python}
@pl.data_loader
def tng_dataloader(self):
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])
dataset = MNIST(root='/path/to/mnist/', train=True, transform=transform, download=True)
loader = torch.utils.data.DataLoader(
dataset=dataset,
batch_size=self.hparams.batch_size,
shuffle=True
)
return loader
```
---
### configure_optimizers
``` {.python}
def configure_optimizers(self)
```
Set up as many optimizers and (optionally) learning rate schedulers as you need. Normally you'd need one. But in the case of GANs or something more esoteric you might have multiple.
Lightning will call .backward() and .step() on each one in every epoch. If you use 16 bit precision it will also handle that.
##### Return
List or Tuple - List of optimizers with an optional second list of learning-rate schedulers
**Example**
``` {.python}
# most cases
def configure_optimizers(self):
opt = Adam(self.parameters(), lr=0.01)
return [opt]
# gan example, with scheduler for discriminator
def configure_optimizers(self):
generator_opt = Adam(self.model_gen.parameters(), lr=0.01)
discriminator_opt = Adam(self.model_disc.parameters(), lr=0.02)
discriminator_sched = CosineAnnealingLR(discriminator_opt, T_max=10)
return [generator_opt, discriminator_opt], [discriminator_sched]
```
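As a rough mental model of how the returned optimizers are consumed (a simplified sketch, not the actual Trainer code), each optimizer gets its own backward and step on every pass. The toy models below are stand-ins for model_gen and model_disc:
``` {.python}
import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR

# toy stand-ins for model_gen / model_disc in the GAN example above
model_gen = torch.nn.Linear(4, 4)
model_disc = torch.nn.Linear(4, 1)

generator_opt = Adam(model_gen.parameters(), lr=0.01)
discriminator_opt = Adam(model_disc.parameters(), lr=0.02)
discriminator_sched = CosineAnnealingLR(discriminator_opt, T_max=10)

x = torch.randn(8, 4)
for optimizer in [generator_opt, discriminator_opt]:
    # each optimizer gets a fresh forward/backward and its own step
    loss = model_disc(model_gen(x)).mean()  # toy loss; a real GAN alternates losses
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# schedulers returned in the second list are also driven for you;
# stepping manually here is only to show the object in use
discriminator_sched.step()
```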
## Optional Methods
### validation_step
``` {.python}
def validation_step(self, data_batch, batch_nb)
def validation_step(self, data_batch, batch_nb, dataloader_i)
```
**OPTIONAL**
If you don't need to validate you don't need to implement this method.
In this step you'd normally do the forward pass and calculate the loss for a batch. You can also do fancier things like multiple forward passes, calculate accuracy, or save example outputs (using self.experiment or whatever you want). Really, anything you want.
In this step you'd normally do the forward pass and calculate the loss for a batch. You can also do fancier things like multiple forward passes or something specific to your model.
This is most likely the same as your training_step. But unlike training step, the outputs from here will go to validation_end for collation.
**Params**
@ -151,6 +215,7 @@ This is most likely the same as your training_step. But unlike training step, th
|---|---|
| data_batch | The output of your dataloader. A tensor, tuple or list |
| batch_nb | Integer indicating which batch this is |
| dataloader_i | Integer indicating which dataloader this is |
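For example, a validation_step could key its outputs by dataloader_i so the validation sets stay distinguishable (an illustrative sketch; the key names are not prescribed by the interface):
``` {.python}
def validation_step(self, data_batch, batch_nb, dataloader_i):
    x, y = data_batch
    y_hat = self.forward(x)

    output = {f'val_loss_{dataloader_i}': F.cross_entropy(y_hat, y)}

    # per-loader logic is fine too, e.g. extra metrics for the second loader only
    if dataloader_i == 1:
        labels_hat = torch.argmax(y_hat, dim=1)
        output[f'val_acc_{dataloader_i}'] = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    return output
```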
**Return**
@ -188,9 +253,12 @@ def validation_step(self, data_batch, batch_nb):
``` {.python}
def validation_end(self, outputs)
```
If you didn't define a validation_step, this won't be called.
Called at the end of the validation loop with the output of each validation_step.
Called at the end of the validation loop with the output of each validation_step. Called once per validation dataset.
The outputs here are strictly for the progress bar. If you don't need to display anything, don't return anything.
**Params**
@ -225,36 +293,6 @@ def validation_end(self, outputs):
return tqdm_dic
```
---
### configure_optimizers
``` {.python}
def configure_optimizers(self)
```
Set up as many optimizers and (optionally) learning rate schedulers as you need. Normally you'd need one. But in the case of GANs or something more esoteric you might have multiple.
Lightning will call .backward() and .step() on each one in every epoch. If you use 16 bit precision it will also handle that.
##### Return
List or Tuple - List of optimizers with an optional second list of learning-rate schedulers
**Example**
``` {.python}
# most cases
def configure_optimizers(self):
opt = Adam(self.parameters(), lr=0.01)
return [opt]
# gan example, with scheduler for discriminator
def configure_optimizers(self):
generator_opt = Adam(self.model_gen.parameters(), lr=0.01)
discriminator_opt = Adam(self.model_disc.parameters(), lr=0.02)
discriminator_sched = CosineAnnealingLR(discriminator_opt, T_max=10)
return [generator_opt, discriminator_opt], [discriminator_sched]
```
---
### on_save_checkpoint
@ -297,33 +335,6 @@ def on_load_checkpoint(self, checkpoint):
self.something_cool_i_want_to_save = checkpoint['something_cool_i_want_to_save']
```
---
### tng_dataloader
``` {.python}
@pl.data_loader
def tng_dataloader(self)
```
Called by lightning during the training loop. Make sure to use the @pl.data_loader decorator; this ensures the function is not called until the data are needed.
##### Return
PyTorch DataLoader
**Example**
``` {.python}
@pl.data_loader
def tng_dataloader(self):
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])
dataset = MNIST(root='/path/to/mnist/', train=True, transform=transform, download=True)
loader = torch.utils.data.DataLoader(
dataset=dataset,
batch_size=self.hparams.batch_size,
shuffle=True
)
return loader
```
---
### val_dataloader
@ -331,10 +342,13 @@ def tng_dataloader(self):
@pl.data_loader
def val_dataloader(self)
```
**OPTIONAL**
If you don't need a validation dataset and a validation_step, you don't need to implement this method.
Called by lightning during the validation loop. Make sure to use the @pl.data_loader decorator; this ensures the function is not called until the data are needed.
##### Return
PyTorch DataLoader
PyTorch DataLoader or list of PyTorch DataLoaders.
**Example**
@ -350,6 +364,11 @@ def val_dataloader(self):
)
return loader
# can also return multiple dataloaders
@pl.data_loader
def val_dataloader(self):
return [loader_a, loader_b, ..., loader_n]
```
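loader_a through loader_n above are placeholders; a concrete two-loader version, reusing the MNIST setup from the earlier examples, could look like this (illustrative only):
``` {.python}
@pl.data_loader
def val_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=False, transform=transform, download=True)

    # two loaders over the same dataset, here just with different batch sizes;
    # validation_step will receive dataloader_i = 0 or 1 accordingly
    loader_a = torch.utils.data.DataLoader(dataset, batch_size=self.hparams.batch_size)
    loader_b = torch.utils.data.DataLoader(dataset, batch_size=self.hparams.batch_size * 2)
    return [loader_a, loader_b]
```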
---
@ -359,6 +378,9 @@ def val_dataloader(self):
@pl.data_loader
def test_dataloader(self)
```
**OPTIONAL**
If you don't need a test dataset and a test_step, you don't need to implement this method.
Called by lightning during the test loop. Make sure to use the @pl.data_loader decorator; this ensures the function is not called until the data are needed.
##### Return


@ -105,7 +105,7 @@ class LightningTemplateModel(LightningModule):
# can also return just a scalar instead of a dict (return loss_val)
return output
def validation_step(self, data_batch, batch_i):
def validation_step(self, data_batch, batch_i, dataloader_i):
"""
Lightning calls this inside the validation loop
:param data_batch:
@ -218,7 +218,7 @@ class LightningTemplateModel(LightningModule):
@pl.data_loader
def val_dataloader(self):
print('val data loader called')
return self.__dataloader(train=False)
return [self.__dataloader(train=False) for i in range(2)]
@pl.data_loader
def test_dataloader(self):


@ -354,7 +354,11 @@ class Trainer(TrainerIO):
self.nb_tng_batches = int(self.nb_tng_batches * self.train_percent_check)
# determine number of validation batches
self.nb_val_batches = len(self.val_dataloader) if self.val_dataloader is not None else 0
# val datasets could be none, 1 or 2+
self.nb_val_batches = 0
if self.val_dataloader is not None:
self.nb_val_batches = sum(len(dataloader) for dataloader in self.val_dataloader)
self.nb_val_batches = int(self.nb_val_batches * self.val_percent_check)
self.nb_val_batches = max(1, self.nb_val_batches)
self.nb_val_batches = self.nb_val_batches
@ -373,7 +377,7 @@ class Trainer(TrainerIO):
self.tqdm_metrics[k] = v
def validate(self, model, dataloader, max_batches):
def validate(self, model, dataloader, max_batches, dataloader_i):
"""
Run validation code
:param model: PT model
@ -381,9 +385,6 @@ class Trainer(TrainerIO):
:param max_batches: Scalar
:return:
"""
# skip validation if model has no validation_step defined
if not self.__is_overriden('validation_step'):
return {}
# enable eval mode
model.zero_grad()
@ -409,9 +410,9 @@ class Trainer(TrainerIO):
# RUN VALIDATION STEP
# -----------------
if self.use_ddp:
output = model(data_batch, batch_i)
output = model(data_batch, batch_i, dataloader_i)
elif self.use_dp:
output = model(data_batch, batch_i)
output = model(data_batch, batch_i, dataloader_i)
elif self.single_gpu:
# put inputs on gpu manually
gpu_id = self.data_parallel_device_ids[0]
@ -420,10 +421,10 @@ class Trainer(TrainerIO):
data_batch[i] = x.cuda(gpu_id)
# do non dp, ddp step
output = model.validation_step(data_batch, batch_i)
output = model.validation_step(data_batch, batch_i, dataloader_i)
else:
output = model.validation_step(data_batch, batch_i)
output = model.validation_step(data_batch, batch_i, dataloader_i)
outputs.append(output)
@ -458,6 +459,11 @@ class Trainer(TrainerIO):
self.test_dataloader = model.test_dataloader
self.val_dataloader = model.val_dataloader
# handle returning an actual dataloader instead of a list of loaders
have_val_loaders = self.val_dataloader is not None
if have_val_loaders and not isinstance(self.val_dataloader, list):
self.val_dataloader = [self.val_dataloader]
if self.use_ddp and not isinstance(self.tng_dataloader.sampler, DistributedSampler):
msg = """
You're using multiple gpus and multiple nodes without using a DistributedSampler
@ -473,6 +479,28 @@ dataset = myDataset()
dist_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
dataloader = Dataloader(dataset, sampler=dist_sampler)
If you want each process to load the full dataset, ignore this warning.
"""
warnings.warn(msg)
if self.use_ddp and\
not all(isinstance(dataloader.sampler, DistributedSampler)
for dataloader in self.val_dataloader):
msg = """
Your val_dataloader(s) are not all using a DistributedSampler.
You're using multiple gpus and multiple nodes without using a DistributedSampler
to assign a subset of your data to each process. To silence this warning, pass a
DistributedSampler to your DataLoader.
ie: this:
dataset = myDataset()
dataloader = Dataloader(dataset)
becomes:
dataset = myDataset()
dist_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
dataloader = Dataloader(dataset, sampler=dist_sampler)
If you want each process to load the full dataset, ignore this warning.
"""
warnings.warn(msg)
@ -721,9 +749,11 @@ We recommend you switch to ddp if you want to use amp
if self.cluster is not None: # pragma: no cover
self.enable_auto_hpc_walltime_manager()
# run tiny validation to make sure program won't crash during val
# run tiny validation (if validation defined) to make sure program won't crash during val
ref_model.on_sanity_check_start()
_ = self.validate(model, self.val_dataloader, max_batches=self.nb_sanity_val_steps)
if self.val_dataloader is not None:
for ds_i, dataloader in enumerate(self.val_dataloader):
self.validate(model, dataloader, self.nb_sanity_val_steps, ds_i)
# ---------------------------
# CORE TRAINING LOOP
@ -988,30 +1018,30 @@ We recommend you switch to ddp if you want to use amp
elif not can_check_epoch:
return
# hook
if self.__is_function_implemented('on_pre_performance_check'):
model = self.__get_model()
model.on_pre_performance_check()
# validate only if model has validation_step defined
if self.__is_overriden('validation_step'):
# use full val set on end of epoch
# use a small portion otherwise
max_batches = None if not self.fast_dev_run else 1
validation_results = self.validate(
self.model,
self.val_dataloader,
max_batches
)
self.__add_tqdm_metrics(validation_results)
# hook
if self.__is_function_implemented('on_pre_performance_check'):
model = self.__get_model()
model.on_pre_performance_check()
# hook
if self.__is_function_implemented('on_post_performance_check'):
model = self.__get_model()
model.on_post_performance_check()
# use full val set on end of epoch
# use a small portion otherwise
max_batches = None if not self.fast_dev_run else 1
for ds_i, dataloader in enumerate(self.val_dataloader):
val_out_metrics = self.validate(self.model, dataloader, max_batches, ds_i)
self.__add_tqdm_metrics(val_out_metrics)
if self.progress_bar:
# add model specific metrics
tqdm_metrics = self.__tng_tqdm_dic
self.prog_bar.set_postfix(**tqdm_metrics)
# hook
if self.__is_function_implemented('on_post_performance_check'):
model = self.__get_model()
model.on_post_performance_check()
if self.progress_bar:
# add model specific metrics
tqdm_metrics = self.__tng_tqdm_dic
self.prog_bar.set_postfix(**tqdm_metrics)
# model checkpointing
if self.proc_rank == 0 and self.checkpoint_callback is not None:


@ -109,7 +109,7 @@ class LightningTestModel(LightningModule):
if self.trainer.batch_nb % 2 == 0:
return loss_val
def validation_step(self, data_batch, batch_i):
def validation_step(self, data_batch, batch_i, dataloader_i):
"""
Lightning calls this inside the validation loop
:param data_batch:
@ -151,6 +151,12 @@ class LightningTestModel(LightningModule):
'test_dic': {'val_loss_a': loss_val}
})
return output
if batch_i % 5 == 0:
output = OrderedDict({
f'val_loss_{dataloader_i}': loss_val,
f'val_acc_{dataloader_i}': val_acc,
})
return output
def validation_end(self, outputs):
"""


@ -109,7 +109,7 @@ class NoValEndTestModel(LightningModule):
if self.trainer.batch_nb % 2 == 0:
return loss_val
def validation_step(self, data_batch, batch_i):
def validation_step(self, data_batch, batch_i, dataloader_i):
"""
Lightning calls this inside the validation loop
:param data_batch:


@ -252,7 +252,7 @@ def test_cpu_restore_training():
# if model and state loaded correctly, predictions will be good even though we
# haven't trained with the new loaded model
trainer.model.eval()
run_prediction(trainer.val_dataloader, trainer.model)
_ = [run_prediction(dataloader, trainer.model) for dataloader in trainer.val_dataloader]
model.on_sanity_check_start = assert_good_acc
@ -761,6 +761,37 @@ def test_ddp_sampler_error():
clear_save_dir()
def test_multiple_val_dataloader():
"""
Verify multiple val_dataloaders
:return:
"""
hparams = get_hparams()
model = LightningTemplateModel(hparams)
save_dir = init_save_dir()
# exp file to get meta
trainer_options = dict(
max_nb_epochs=1,
val_percent_check=0.1,
train_percent_check=0.1,
)
# fit model
trainer = Trainer(**trainer_options)
result = trainer.fit(model)
# verify tng completed
assert result == 1
# verify there are 2 val loaders
assert len(trainer.val_dataloader) == 2, 'Multiple val_dataloaders not initiated properly'
# make sure predictions are good for each val set
[run_prediction(dataloader, trainer.model) for dataloader in trainer.val_dataloader]
# ------------------------------------------------------------------------
# UTILS
# ------------------------------------------------------------------------