Support for multiple val_dataloaders (#97)
* Added support for multiple validation dataloaders
* Fix typo in README.md
* Update trainer.py
* Add support for multiple dataloaders
* Rename dataloader_index to dataloader_i
* Added warning to check val_dataloaders: added a warning to ensure that all val_dataloaders were DistributedSamplers if ddp is enabled
* Updated DistributedSampler warning
* Fixed typo
* Added multiple val_dataloaders
* Multiple val_dataloader test
* Update lightning_module_template.py: added dataloader_i to validation_step parameters
* Update trainer.py
* Reverted template changes
* Create multi_val_module.py
* Update no_val_end_module.py
* New MultiValModel
* Rename MultiValModel to MultiValTestModel
* Revert to LightningTestModel
* Update test_models.py
* Update trainer.py
* Update test_models.py
* multiple val_dataloaders in test template
* Fixed flake8 warnings
* Update trainer.py
* Fix flake errors
* Fixed Flake8 errors
* Update lm_test_module.py: keep this test model with a single dataset for val
* Update trainer.py
* Update trainer.py
* Update trainer.py
* Update trainer.py
* Update trainer.py
* Update test_models.py
* Update trainer.py
* Update trainer.py
* Update trainer.py
* Update trainer.py
* Update trainer.py
* Update trainer.py
* Update RequiredTrainerInterface.md
* Update RequiredTrainerInterface.md
* Update test_models.py
* Update trainer.py: don't need the else clause, val_dataloader is either a list or none because of get_dataloaders()
* Update trainer.py: fixed flake errors
* Update trainer.py
parent 46e27e38aa · commit 511f7ecb9a
@@ -46,24 +46,25 @@ class CoolModel(pl.LightningModule):
    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def my_loss(self, y_hat, y):
        return F.cross_entropy(y_hat, y)

    def training_step(self, batch, batch_nb):
        # REQUIRED
        x, y = batch
        y_hat = self.forward(x)
        return {'loss': self.my_loss(y_hat, y)}
        return {'loss': F.cross_entropy(y_hat, y)}

    def validation_step(self, batch, batch_nb):
        # OPTIONAL
        x, y = batch
        y_hat = self.forward(x)
        return {'val_loss': self.my_loss(y_hat, y)}
        return {'val_loss': F.cross_entropy(y_hat, y)}

    def validation_end(self, outputs):
        # OPTIONAL
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        return {'avg_val_loss': avg_loss}

    def configure_optimizers(self):
        # REQUIRED
        return [torch.optim.Adam(self.parameters(), lr=0.02)]

    @pl.data_loader
@@ -72,10 +73,13 @@ class CoolModel(pl.LightningModule):
    @pl.data_loader
    def val_dataloader(self):
        # OPTIONAL
        # can also return a list of val dataloaders
        return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)

    @pl.data_loader
    def test_dataloader(self):
        # OPTIONAL
        return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)
```
---

@@ -88,7 +92,7 @@ The LightningModule interface is on the right. Each method corresponds to a part
</a>
</p>

---
## Required Methods

### training_step

@@ -134,15 +138,75 @@ def training_step(self, data_batch, batch_nb):
    return output
```

---
---
### tng_dataloader

``` {.python}
@pl.data_loader
def tng_dataloader(self)
```
Called by lightning during the training loop. Make sure to use the @pl.data_loader decorator; this ensures the function is not called until the data are needed.

##### Return
PyTorch DataLoader

**Example**

``` {.python}
@pl.data_loader
def tng_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=True, transform=transform, download=True)
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=self.hparams.batch_size,
        shuffle=True
    )
    return loader
```

---
### configure_optimizers

``` {.python}
def configure_optimizers(self)
```

Set up as many optimizers and (optionally) learning rate schedulers as you need. Normally you'd need one, but for GANs or something more esoteric you might have multiple.
Lightning will call .backward() and .step() on each one in every epoch. If you use 16-bit precision it will also handle that.


##### Return
List or Tuple - List of optimizers with an optional second list of learning-rate schedulers

**Example**

``` {.python}
# most cases
def configure_optimizers(self):
    opt = Adam(self.parameters(), lr=0.01)
    return [opt]

# gan example, with scheduler for discriminator
def configure_optimizers(self):
    generator_opt = Adam(self.model_gen.parameters(), lr=0.01)
    discriminator_opt = Adam(self.model_disc.parameters(), lr=0.02)
    discriminator_sched = CosineAnnealing(discriminator_opt, T_max=10)
    return [generator_opt, discriminator_opt], [discriminator_sched]
```

## Optional Methods

### validation_step

``` {.python}
def validation_step(self, data_batch, batch_nb)
def validation_step(self, data_batch, batch_nb, dataloader_i)
```
**OPTIONAL**
If you don't need to validate, you don't need to implement this method.

In this step you'd normally do the forward pass and calculate the loss for a batch. You can also do fancier things like multiple forward passes, calculate accuracy, or save example outputs (using self.experiment or whatever you want). Really, anything you want.

In this step you'd normally do the forward pass and calculate the loss for a batch. You can also do fancier things like multiple forward passes or something specific to your model.
This is most likely the same as your training_step. But unlike training_step, the outputs from here will go to validation_end for collation.
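A minimal sketch of the multi-loader form (not part of this diff; the per-index metric keys mirror the test-model changes further down and are purely illustrative):

``` {.python}
# illustrative sketch only: key the loss by dataloader index so
# validation_end can tell the validation datasets apart
def validation_step(self, data_batch, batch_nb, dataloader_i):
    x, y = data_batch
    y_hat = self.forward(x)
    return {f'val_loss_{dataloader_i}': F.cross_entropy(y_hat, y)}
```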

**Params**

@@ -151,6 +215,7 @@ This is most likely the same as your training_step. But unlike training step, th
|---|---|
| data_batch | The output of your dataloader. A tensor, tuple or list |
| batch_nb | Integer displaying which batch this is |
| dataloader_i | Integer displaying which dataloader this is |

**Return**

@@ -188,9 +253,12 @@ def validation_step(self, data_batch, batch_nb):

``` {.python}
def validation_end(self, outputs)
```
If you didn't define a validation_step, this won't be called.

Called at the end of the validation loop with the output of each validation_step.
Called at the end of the validation loop with the output of each validation_step. Called once per validation dataset.

The outputs here are strictly for the progress bar. If you don't need to display anything, don't return anything.
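Since validation_end now runs once per validation dataset, collation can simply average whatever per-dataset keys that dataset's steps returned. A minimal sketch (not part of this diff; it assumes each output dict holds scalar tensors, e.g. the f'val_loss_{dataloader_i}' keys used in the test model below):

``` {.python}
# illustrative sketch only: average the scalar tensors returned for this dataset
def validation_end(self, outputs):
    tqdm_dic = {}
    for key in outputs[0].keys():
        tqdm_dic[key] = torch.stack([output[key] for output in outputs]).mean()
    return tqdm_dic
```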

**Params**

@@ -225,36 +293,6 @@ def validation_end(self, outputs):
    return tqdm_dic
```

---
### configure_optimizers

``` {.python}
def configure_optimizers(self)
```

Set up as many optimizers and (optionally) learning rate schedulers as you need. Normally you'd need one, but for GANs or something more esoteric you might have multiple.
Lightning will call .backward() and .step() on each one in every epoch. If you use 16-bit precision it will also handle that.


##### Return
List or Tuple - List of optimizers with an optional second list of learning-rate schedulers

**Example**

``` {.python}
# most cases
def configure_optimizers(self):
    opt = Adam(self.parameters(), lr=0.01)
    return [opt]

# gan example, with scheduler for discriminator
def configure_optimizers(self):
    generator_opt = Adam(self.model_gen.parameters(), lr=0.01)
    discriminator_opt = Adam(self.model_disc.parameters(), lr=0.02)
    discriminator_sched = CosineAnnealing(discriminator_opt, T_max=10)
    return [generator_opt, discriminator_opt], [discriminator_sched]
```

---
### on_save_checkpoint

@@ -297,33 +335,6 @@ def on_load_checkpoint(self, checkpoint):
    self.something_cool_i_want_to_save = checkpoint['something_cool_i_want_to_save']
```

---
### tng_dataloader

``` {.python}
@pl.data_loader
def tng_dataloader(self)
```
Called by lightning during the training loop. Make sure to use the @pl.data_loader decorator; this ensures the function is not called until the data are needed.

##### Return
PyTorch DataLoader

**Example**

``` {.python}
@pl.data_loader
def tng_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=True, transform=transform, download=True)
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=self.hparams.batch_size,
        shuffle=True
    )
    return loader
```

---
### val_dataloader

@@ -331,10 +342,13 @@ def tng_dataloader(self):
@pl.data_loader
def val_dataloader(self)
```
Called by lightning during the validation loop. Make sure to use the @pl.data_loader decorator; this ensures the function is not called until the data are needed.
**OPTIONAL**
If you don't need a validation dataset and a validation_step, you don't need to implement this method.

Called by lightning during the validation loop. Make sure to use the @pl.data_loader decorator; this ensures the function is not called until the data are needed.

##### Return
PyTorch DataLoader
PyTorch DataLoader or list of PyTorch DataLoaders.

**Example**

@@ -350,6 +364,11 @@ def val_dataloader(self):
    )

    return loader

# can also return multiple dataloaders
@pl.data_loader
def val_dataloader(self):
    return [loader_a, loader_b, ..., loader_n]
```
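As a rough illustration of the list form (not part of this diff), the sketch below reuses the MNIST/DataLoader imports from the CoolModel example at the top; the two loaders are placeholders for whatever validation sets you actually have, and each is validated separately with its index passed to validation_step as dataloader_i.

``` {.python}
# illustrative sketch only: two loaders over the same MNIST split;
# in practice these would usually wrap different validation datasets
@pl.data_loader
def val_dataloader(self):
    mnist_val = MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor())
    loader_a = DataLoader(mnist_val, batch_size=32)
    loader_b = DataLoader(mnist_val, batch_size=64)
    return [loader_a, loader_b]
```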

---

@@ -359,6 +378,9 @@ def val_dataloader(self):
@pl.data_loader
def test_dataloader(self)
```
**OPTIONAL**
If you don't need a test dataset and a test_step, you don't need to implement this method.

Called by lightning during the test loop. Make sure to use the @pl.data_loader decorator; this ensures the function is not called until the data are needed.

##### Return
@@ -105,7 +105,7 @@ class LightningTemplateModel(LightningModule):
        # can also return just a scalar instead of a dict (return loss_val)
        return output

    def validation_step(self, data_batch, batch_i):
    def validation_step(self, data_batch, batch_i, dataloader_i):
        """
        Lightning calls this inside the validation loop
        :param data_batch:

@@ -218,7 +218,7 @@ class LightningTemplateModel(LightningModule):
    @pl.data_loader
    def val_dataloader(self):
        print('val data loader called')
        return self.__dataloader(train=False)
        return [self.__dataloader(train=False) for i in range(2)]

    @pl.data_loader
    def test_dataloader(self):
@@ -354,7 +354,11 @@ class Trainer(TrainerIO):
        self.nb_tng_batches = int(self.nb_tng_batches * self.train_percent_check)

        # determine number of validation batches
        self.nb_val_batches = len(self.val_dataloader) if self.val_dataloader is not None else 0
        # val datasets could be none, 1 or 2+
        self.nb_val_batches = 0
        if self.val_dataloader is not None:
            self.nb_val_batches = sum(len(dataloader) for dataloader in self.val_dataloader)

        self.nb_val_batches = int(self.nb_val_batches * self.val_percent_check)
        self.nb_val_batches = max(1, self.nb_val_batches)
        self.nb_val_batches = self.nb_val_batches

@@ -373,7 +377,7 @@ class Trainer(TrainerIO):

            self.tqdm_metrics[k] = v

    def validate(self, model, dataloader, max_batches):
    def validate(self, model, dataloader, max_batches, dataloader_i):
        """
        Run validation code
        :param model: PT model

@@ -381,9 +385,6 @@ class Trainer(TrainerIO):
        :param max_batches: Scalar
        :return:
        """
        # skip validation if model has no validation_step defined
        if not self.__is_overriden('validation_step'):
            return {}

        # enable eval mode
        model.zero_grad()
@@ -409,9 +410,9 @@ class Trainer(TrainerIO):
            # RUN VALIDATION STEP
            # -----------------
            if self.use_ddp:
                output = model(data_batch, batch_i)
                output = model(data_batch, batch_i, dataloader_i)
            elif self.use_dp:
                output = model(data_batch, batch_i)
                output = model(data_batch, batch_i, dataloader_i)
            elif self.single_gpu:
                # put inputs on gpu manually
                gpu_id = self.data_parallel_device_ids[0]

@@ -420,10 +421,10 @@ class Trainer(TrainerIO):
                        data_batch[i] = x.cuda(gpu_id)

                # do non dp, ddp step
                output = model.validation_step(data_batch, batch_i)
                output = model.validation_step(data_batch, batch_i, dataloader_i)

            else:
                output = model.validation_step(data_batch, batch_i)
                output = model.validation_step(data_batch, batch_i, dataloader_i)

            outputs.append(output)
@@ -458,6 +459,11 @@ class Trainer(TrainerIO):
        self.test_dataloader = model.test_dataloader
        self.val_dataloader = model.val_dataloader

        # handle returning an actual dataloader instead of a list of loaders
        have_val_loaders = self.val_dataloader is not None
        if have_val_loaders and not isinstance(self.val_dataloader, list):
            self.val_dataloader = [self.val_dataloader]

        if self.use_ddp and not isinstance(self.tng_dataloader.sampler, DistributedSampler):
            msg = """
You're using multiple gpus and multiple nodes without using a DistributedSampler

@@ -473,6 +479,28 @@ dataset = myDataset()
dist_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
dataloader = Dataloader(dataset, sampler=dist_sampler)

If you want each process to load the full dataset, ignore this warning.
"""
            warnings.warn(msg)

        if self.use_ddp and\
                not all(isinstance(dataloader.sampler, DistributedSampler)
                        for dataloader in self.val_dataloader):
            msg = """
Your val_dataloader(s) are not all using a DistributedSampler.
You're using multiple gpus and multiple nodes without using a DistributedSampler
to assign a subset of your data to each process. To silence this warning, pass a
DistributedSampler to your DataLoader.

ie: this:
dataset = myDataset()
dataloader = Dataloader(dataset)

becomes:
dataset = myDataset()
dist_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
dataloader = Dataloader(dataset, sampler=dist_sampler)

If you want each process to load the full dataset, ignore this warning.
"""
            warnings.warn(msg)
@@ -721,9 +749,11 @@ We recommend you switch to ddp if you want to use amp
        if self.cluster is not None:  # pragma: no cover
            self.enable_auto_hpc_walltime_manager()

        # run tiny validation to make sure program won't crash during val
        # run tiny validation (if validation defined) to make sure program won't crash during val
        ref_model.on_sanity_check_start()
        _ = self.validate(model, self.val_dataloader, max_batches=self.nb_sanity_val_steps)
        if self.val_dataloader is not None:
            for ds_i, dataloader in enumerate(self.val_dataloader):
                self.validate(model, dataloader, self.nb_sanity_val_steps, ds_i)

        # ---------------------------
        # CORE TRAINING LOOP
@@ -988,30 +1018,30 @@ We recommend you switch to ddp if you want to use amp
        elif not can_check_epoch:
            return

        # hook
        if self.__is_function_implemented('on_pre_performance_check'):
            model = self.__get_model()
            model.on_pre_performance_check()
        # validate only if model has validation_step defined
        if self.__is_overriden('validation_step'):

        # use full val set on end of epoch
        # use a small portion otherwise
        max_batches = None if not self.fast_dev_run else 1
        validation_results = self.validate(
            self.model,
            self.val_dataloader,
            max_batches
        )
        self.__add_tqdm_metrics(validation_results)
            # hook
            if self.__is_function_implemented('on_pre_performance_check'):
                model = self.__get_model()
                model.on_pre_performance_check()

        # hook
        if self.__is_function_implemented('on_post_performance_check'):
            model = self.__get_model()
            model.on_post_performance_check()
            # use full val set on end of epoch
            # use a small portion otherwise
            max_batches = None if not self.fast_dev_run else 1
            for ds_i, dataloader in enumerate(self.val_dataloader):
                val_out_metrics = self.validate(self.model, dataloader, max_batches, ds_i)
                self.__add_tqdm_metrics(val_out_metrics)

        if self.progress_bar:
            # add model specific metrics
            tqdm_metrics = self.__tng_tqdm_dic
            self.prog_bar.set_postfix(**tqdm_metrics)
            # hook
            if self.__is_function_implemented('on_post_performance_check'):
                model = self.__get_model()
                model.on_post_performance_check()

            if self.progress_bar:
                # add model specific metrics
                tqdm_metrics = self.__tng_tqdm_dic
                self.prog_bar.set_postfix(**tqdm_metrics)

        # model checkpointing
        if self.proc_rank == 0 and self.checkpoint_callback is not None:
@@ -109,7 +109,7 @@ class LightningTestModel(LightningModule):
        if self.trainer.batch_nb % 2 == 0:
            return loss_val

    def validation_step(self, data_batch, batch_i):
    def validation_step(self, data_batch, batch_i, dataloader_i):
        """
        Lightning calls this inside the validation loop
        :param data_batch:

@@ -151,6 +151,12 @@ class LightningTestModel(LightningModule):
            'test_dic': {'val_loss_a': loss_val}
        })
        return output
        if batch_i % 5 == 0:
            output = OrderedDict({
                f'val_loss_{dataloader_i}': loss_val,
                f'val_acc_{dataloader_i}': val_acc,
            })
            return output

    def validation_end(self, outputs):
        """
@@ -109,7 +109,7 @@ class NoValEndTestModel(LightningModule):
        if self.trainer.batch_nb % 2 == 0:
            return loss_val

    def validation_step(self, data_batch, batch_i):
    def validation_step(self, data_batch, batch_i, dataloader_i):
        """
        Lightning calls this inside the validation loop
        :param data_batch:
@@ -252,7 +252,7 @@ def test_cpu_restore_training():
        # if model and state loaded correctly, predictions will be good even though we
        # haven't trained with the new loaded model
        trainer.model.eval()
        run_prediction(trainer.val_dataloader, trainer.model)
        _ = [run_prediction(dataloader, trainer.model) for dataloader in trainer.val_dataloader]

    model.on_sanity_check_start = assert_good_acc
@@ -761,6 +761,37 @@ def test_ddp_sampler_error():
    clear_save_dir()


def test_multiple_val_dataloader():
    """
    Verify multiple val_dataloader
    :return:
    """
    hparams = get_hparams()
    model = LightningTemplateModel(hparams)

    save_dir = init_save_dir()

    # exp file to get meta
    trainer_options = dict(
        max_nb_epochs=1,
        val_percent_check=0.1,
        train_percent_check=0.1,
    )

    # fit model
    trainer = Trainer(**trainer_options)
    result = trainer.fit(model)

    # verify tng completed
    assert result == 1

    # verify there are 2 val loaders
    assert len(trainer.val_dataloader) == 2, 'Multiple val_dataloaders not initiated properly'

    # make sure predictions are good for each val set
    [run_prediction(dataloader, trainer.model) for dataloader in trainer.val_dataloader]


# ------------------------------------------------------------------------
# UTILS
# ------------------------------------------------------------------------