Add support for iterable datasets when val_check_interval=1.0 (#1283)

* Add support for iterable datasets when val_check_interval=1.0

* Update CHANGELOG.md
Ethan Harris, 2020-03-29 20:27:44 +01:00, committed by GitHub
parent 54507f417e
commit ab09faa15e
4 changed files with 41 additions and 14 deletions

CHANGELOG.md

@@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added support for non-primitive types in `hparams` for `TensorboardLogger` ([#1130](https://github.com/PyTorchLightning/pytorch-lightning/pull/1130))
 - Added a check that stops the training when loss or weights contain `NaN` or `inf` values. ([#1097](https://github.com/PyTorchLightning/pytorch-lightning/pull/1097))
 - Updated references to self.forward() to instead use the `__call__` interface. ([#1211](https://github.com/PyTorchLightning/pytorch-lightning/pull/1211))
+- Added support for `IterableDataset` when `val_check_interval=1.0` (default), this will trigger validation at the end of each epoch. ([#1283](https://github.com/PyTorchLightning/pytorch-lightning/pull/1283))
 
 ### Changed
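In user terms: the default Trainer configuration now works out of the box with train loaders that have no length. A minimal sketch of the scenario (the `Stream` dataset and the loader wiring are illustrative, not taken from the repository):

    import torch
    from torch.utils.data import DataLoader, IterableDataset
    from pytorch_lightning import Trainer

    class Stream(IterableDataset):
        """Iterable-style dataset: defines no __len__, so the number of
        training batches is unknown ahead of time."""
        def __iter__(self):
            for _ in range(100):
                yield torch.randn(32)

    train_loader = DataLoader(Stream())  # len(train_loader) is not defined

    # Before this change the default val_check_interval=1.0 raised a
    # MisconfigurationException for such a loader; now validation runs
    # once the loader is exhausted, i.e. at the end of each epoch.
    trainer = Trainer(max_epochs=1)  # val_check_interval defaults to 1.0
    # trainer.fit(model)  # model.train_dataloader() would return train_loader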


@@ -136,16 +136,19 @@ class TrainerDataLoadingMixin(ABC):
                     'If you want to disable validation set `val_percent_check` to 0.0 instead.')
         else:
             if not _has_len(self.train_dataloader):
-                raise MisconfigurationException(
-                    'When using an infinite DataLoader (e.g. with an IterableDataset or when '
-                    'DataLoader does not implement `__len__`) for `train_dataloader`, '
-                    '`Trainer(val_check_interval)` must be an int. An int k specifies checking '
-                    'validation every k training batches.')
+                if self.val_check_interval == 1.0:
+                    self.val_check_batch = float('inf')
+                else:
+                    raise MisconfigurationException(
+                        'When using an infinite DataLoader (e.g. with an IterableDataset or when '
+                        'DataLoader does not implement `__len__`) for `train_dataloader`, '
+                        '`Trainer(val_check_interval)` must be `1.0` or an int. An int k specifies '
+                        'checking validation every k training batches.')
             else:
                 self._percent_range_check('val_check_interval')
 
                 self.val_check_batch = int(self.num_training_batches * self.val_check_interval)
                 self.val_check_batch = max(1, self.val_check_batch)
 
     def _reset_eval_dataloader(self, model: LightningModule,
                                mode: str) -> Tuple[int, List[DataLoader]]:
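The new branch hinges on `_has_len`, a helper defined in the same module that reports whether a loader supports `len()`. A minimal sketch of that check, assuming PyTorch raises when `len()` is called on a DataLoader backed by an IterableDataset (the real helper may differ in detail):

    def _has_len(dataloader) -> bool:
        """Return True if the dataloader defines a usable len()."""
        try:
            # a DataLoader over an IterableDataset raises here
            len(dataloader)
            return True
        except (TypeError, NotImplementedError):
            return False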


@@ -400,8 +400,8 @@ class TrainerTrainLoopMixin(ABC):
             train_dataloader = train_dataloader.per_device_loader(device)
 
         # run epoch
-        for batch_idx, batch in self.profiler.profile_iterable(
-                enumerate(train_dataloader), "get_train_batch"
+        for batch_idx, (batch, is_last_batch) in self.profiler.profile_iterable(
+                enumerate(_with_is_last(train_dataloader)), "get_train_batch"
         ):
             # stop epoch if we limited the number of training batches
             if batch_idx >= self.num_training_batches:
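Here `_with_is_last` (added at the bottom of this file, below) tags each batch with an end-of-epoch flag, while `profile_iterable` passes batches through and times each fetch. A rough sketch of such a pass-through profiling wrapper (the real implementation lives in the profiler module and may differ):

    def profile_iterable(self, iterable, action_name):
        iterator = iter(iterable)
        while True:
            try:
                # time only the cost of fetching the next element
                with self.profile(action_name):
                    value = next(iterator)
                yield value
            except StopIteration:
                break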
@@ -429,8 +429,10 @@ class TrainerTrainLoopMixin(ABC):
             # ---------------
             is_val_check_batch = (batch_idx + 1) % self.val_check_batch == 0
             can_check_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0
-            should_check_val = not self.disable_validation and can_check_epoch
-            should_check_val = should_check_val and (is_val_check_batch or early_stop_epoch)
+            can_check_val = not self.disable_validation and can_check_epoch
+            should_check_val = is_val_check_batch or early_stop_epoch
+            should_check_val = should_check_val or (is_last_batch and self.val_check_batch == float('inf'))
+            should_check_val = can_check_val and should_check_val
 
             # fast_dev_run always forces val checking after train batch
             if self.fast_dev_run or should_check_val:
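Restated as plain Python (a sketch using the same names as the diff, not the Trainer's actual API), the new decision is:

    def should_run_val(batch_idx, val_check_batch, is_last_batch,
                       early_stop_epoch, disable_validation, can_check_epoch):
        # note: x % float('inf') == x in Python, so with an infinite
        # val_check_batch the interval test can never fire mid-epoch ...
        is_val_check_batch = (batch_idx + 1) % val_check_batch == 0
        can_check_val = not disable_validation and can_check_epoch
        should_check_val = is_val_check_batch or early_stop_epoch
        # ... which is why the end of an un-sized epoch is detected explicitly:
        should_check_val = should_check_val or (is_last_batch and val_check_batch == float('inf'))
        return can_check_val and should_check_val

For example, with `val_check_batch = float('inf')` this returns True only on the last batch (or on early stopping), i.e. validation runs exactly once per epoch.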
@@ -740,3 +742,16 @@ class TrainerTrainLoopMixin(ABC):
         if self.checkpoint_callback is not None:
             self.checkpoint_callback.on_validation_end(self, self.get_model())
         self.on_validation_end()
+
+
+def _with_is_last(iterable):
+    """Pass through values from the given iterable with an added boolean indicating if this is the last item.
+    See `https://stackoverflow.com/a/1630350 <https://stackoverflow.com/a/1630350>`_"""
+    it = iter(iterable)
+    last = next(it)
+    for val in it:
+        # yield last and has next
+        yield last, False
+        last = val
+    # yield last, no longer has next
+    yield last, True
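A quick demonstration of the helper, including the tuple shape the rewritten epoch loop unpacks:

    >>> list(_with_is_last([10, 20, 30]))
    [(10, False), (20, False), (30, True)]
    >>> for batch_idx, (batch, is_last_batch) in enumerate(_with_is_last('ab')):
    ...     print(batch_idx, batch, is_last_batch)
    0 a False
    1 b True

Note that the helper assumes a non-empty iterable: on an empty one, the initial `next(it)` raises StopIteration, which Python 3.7+ converts to a RuntimeError inside a generator.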


@@ -372,7 +372,6 @@ def test_inf_train_dataloader(tmpdir):
         )
         trainer.fit(model)
 
-    # logger file to get meta
     trainer = Trainer(
         default_save_path=tmpdir,
         max_epochs=1,
@@ -383,6 +382,15 @@ def test_inf_train_dataloader(tmpdir):
     # verify training completed
     assert result == 1
 
+    trainer = Trainer(
+        default_save_path=tmpdir,
+        max_epochs=1
+    )
+    result = trainer.fit(model)
+
+    # verify training completed
+    assert result == 1
+
 
 def test_inf_val_dataloader(tmpdir):
     """Test inf val data loader (e.g. IterableDataset)"""