Add support for iterable datasets when val_check_interval=1.0 (#1283)
* Add support for iterable datasets when val_check_interval=1.0
* Update CHANGELOG.md
commit ab09faa15e
parent 54507f417e
In `CHANGELOG.md`:

```diff
@@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added support for non-primitive types in `hparams` for `TensorboardLogger` ([#1130](https://github.com/PyTorchLightning/pytorch-lightning/pull/1130))
 - Added a check that stops the training when loss or weights contain `NaN` or `inf` values. ([#1097](https://github.com/PyTorchLightning/pytorch-lightning/pull/1097))
 - Updated references to self.forward() to instead use the `__call__` interface. ([#1211](https://github.com/PyTorchLightning/pytorch-lightning/pull/1211))
+- Added support for `IterableDataset` when `val_check_interval=1.0` (default), this will trigger validation at the end of each epoch. ([#1283](https://github.com/PyTorchLightning/pytorch-lightning/pull/1283))
 
 ### Changed
```
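In practice, the change means a `DataLoader` backed by an `IterableDataset` (which has no `__len__`) now works with the default `val_check_interval=1.0`. A minimal sketch of the user-facing behavior; the dataset and model names here are illustrative, not from the PR:

```python
import torch
from torch.utils.data import DataLoader, IterableDataset


class RandomStream(IterableDataset):
    """Hypothetical stream-style dataset: no __len__, so its DataLoader has no length."""

    def __iter__(self):
        for _ in range(64):  # stands in for an unbounded stream
            yield torch.randn(32)


train_loader = DataLoader(RandomStream(), batch_size=8)

# Before this commit, the default Trainer(val_check_interval=1.0) raised a
# MisconfigurationException for such a loader (an int was required). With the
# change, the default instead validates once at the end of each epoch:
#
#   trainer = Trainer(max_epochs=1)  # val_check_interval defaults to 1.0
#   trainer.fit(model)               # `model` returns train_loader from train_dataloader()
```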
In `pytorch_lightning/trainer/data_loading.py`:

```diff
@@ -136,16 +136,19 @@ class TrainerDataLoadingMixin(ABC):
                     'If you want to disable validation set `val_percent_check` to 0.0 instead.')
         else:
             if not _has_len(self.train_dataloader):
-                raise MisconfigurationException(
-                    'When using an infinite DataLoader (e.g. with an IterableDataset or when '
-                    'DataLoader does not implement `__len__`) for `train_dataloader`, '
-                    '`Trainer(val_check_interval)` must be an int. An int k specifies checking '
-                    'validation every k training batches.')
+                if self.val_check_interval == 1.0:
+                    self.val_check_batch = float('inf')
+                else:
+                    raise MisconfigurationException(
+                        'When using an infinite DataLoader (e.g. with an IterableDataset or when '
+                        'DataLoader does not implement `__len__`) for `train_dataloader`, '
+                        '`Trainer(val_check_interval)` must be `1.0` or an int. An int k specifies '
+                        'checking validation every k training batches.')
+            else:
+                self._percent_range_check('val_check_interval')
 
-            self._percent_range_check('val_check_interval')
-
-            self.val_check_batch = int(self.num_training_batches * self.val_check_interval)
-            self.val_check_batch = max(1, self.val_check_batch)
+                self.val_check_batch = int(self.num_training_batches * self.val_check_interval)
+                self.val_check_batch = max(1, self.val_check_batch)
 
     def _reset_eval_dataloader(self, model: LightningModule,
                                mode: str) -> Tuple[int, List[DataLoader]]:
```
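The new branch hinges on `_has_len`, which reports whether the train dataloader exposes a usable length. A rough, illustrative re-implementation of that check (not the PR's exact code) showing the two paths it selects:

```python
from torch.utils.data import DataLoader, IterableDataset


def _has_len_sketch(dataloader: DataLoader) -> bool:
    """Illustrative stand-in for `_has_len`: True if `len(dataloader)` works.

    A DataLoader over an `IterableDataset` raises TypeError on `len()`, which
    is what pushes the code above into the infinite-loader branch.
    """
    try:
        len(dataloader)
        return True
    except TypeError:
        return False


class _Stream(IterableDataset):  # hypothetical dataset for the demo
    def __iter__(self):
        yield from range(10)


print(_has_len_sketch(DataLoader(_Stream())))   # False -> infinite-loader branch
print(_has_len_sketch(DataLoader(range(10))))   # True  -> percent-based branch
```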
In `pytorch_lightning/trainer/training_loop.py`:

```diff
@@ -400,8 +400,8 @@ class TrainerTrainLoopMixin(ABC):
             train_dataloader = train_dataloader.per_device_loader(device)
 
         # run epoch
-        for batch_idx, batch in self.profiler.profile_iterable(
-            enumerate(train_dataloader), "get_train_batch"
+        for batch_idx, (batch, is_last_batch) in self.profiler.profile_iterable(
+            enumerate(_with_is_last(train_dataloader)), "get_train_batch"
         ):
             # stop epoch if we limited the number of training batches
             if batch_idx >= self.num_training_batches:
```
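`self.profiler.profile_iterable` wraps the batch iterator so that each `next()` call is timed under the `get_train_batch` action. A simplified sketch of that idea (assumption: this is not Lightning's actual implementation, which records into its profiler rather than printing):

```python
import time
from typing import Iterable, Iterator


def profile_iterable_sketch(iterable: Iterable, action_name: str) -> Iterator:
    """Yield items from `iterable`, timing each fetch (illustrative only)."""
    iterator = iter(iterable)
    while True:
        start = time.monotonic()
        try:
            value = next(iterator)
        except StopIteration:
            break
        elapsed = time.monotonic() - start
        print(f"{action_name}: fetched item in {elapsed:.6f}s")
        yield value


# The training loop above consumes it exactly like a plain iterator:
for batch_idx, item in profile_iterable_sketch(enumerate("abc"), "get_train_batch"):
    pass
```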
```diff
@@ -429,8 +429,10 @@ class TrainerTrainLoopMixin(ABC):
             # ---------------
             is_val_check_batch = (batch_idx + 1) % self.val_check_batch == 0
             can_check_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0
-            should_check_val = not self.disable_validation and can_check_epoch
-            should_check_val = should_check_val and (is_val_check_batch or early_stop_epoch)
+            can_check_val = not self.disable_validation and can_check_epoch
+            should_check_val = is_val_check_batch or early_stop_epoch
+            should_check_val = should_check_val or (is_last_batch and self.val_check_batch == float('inf'))
+            should_check_val = can_check_val and should_check_val
 
             # fast_dev_run always forces val checking after train batch
             if self.fast_dev_run or should_check_val:
```
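Why the new `is_last_batch` clause is needed: once `val_check_batch` is `float('inf')`, the modulo test above can never fire, so epoch-end validation must be triggered explicitly. A quick check:

```python
# (batch_idx + 1) % float('inf') is just batch_idx + 1 (as a float), never 0,
# so `is_val_check_batch` stays False for an infinite dataloader:
val_check_batch = float('inf')
for batch_idx in range(3):
    print((batch_idx + 1) % val_check_batch == 0)  # False, False, False

# Validation in the infinite case therefore comes from the new clause
# `is_last_batch and self.val_check_batch == float('inf')`: exactly once,
# on the last batch of the epoch.
```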
```diff
@@ -740,3 +742,16 @@ class TrainerTrainLoopMixin(ABC):
         if self.checkpoint_callback is not None:
             self.checkpoint_callback.on_validation_end(self, self.get_model())
         self.on_validation_end()
+
+
+def _with_is_last(iterable):
+    """Pass through values from the given iterable with an added boolean indicating if this is the last item.
+
+    See `https://stackoverflow.com/a/1630350 <https://stackoverflow.com/a/1630350>`_"""
+    it = iter(iterable)
+    last = next(it)
+    for val in it:
+        # yield last and has next
+        yield last, False
+        last = val
+    # yield last, no longer has next
+    yield last, True
```
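A quick demonstration of the helper's contract (copied locally so it runs standalone). One caveat worth flagging, not something the PR addresses: on an empty iterable the initial `next(it)` raises `StopIteration` inside the generator, which Python 3.7+ (PEP 479) converts to `RuntimeError`:

```python
def _with_is_last(iterable):
    """Local copy of the helper above, for a standalone demo."""
    it = iter(iterable)
    last = next(it)  # on an empty iterable: RuntimeError under PEP 479
    for val in it:
        yield last, False
        last = val
    yield last, True


print(list(_with_is_last(["a", "b", "c"])))
# [('a', False), ('b', False), ('c', True)]

# Combined with enumerate(), as in the training loop above:
for batch_idx, (batch, is_last_batch) in enumerate(_with_is_last(["a", "b"])):
    print(batch_idx, batch, is_last_batch)
# 0 a False
# 1 b True
```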
In the dataloader tests:

```diff
@@ -372,7 +372,6 @@ def test_inf_train_dataloader(tmpdir):
     )
-    trainer.fit(model)
 
     # logger file to get meta
     trainer = Trainer(
         default_save_path=tmpdir,
         max_epochs=1,
```
```diff
@@ -383,6 +382,15 @@ def test_inf_train_dataloader(tmpdir):
     # verify training completed
     assert result == 1
 
+    trainer = Trainer(
+        default_save_path=tmpdir,
+        max_epochs=1
+    )
+    result = trainer.fit(model)
+
+    # verify training completed
+    assert result == 1
+
 
 def test_inf_val_dataloader(tmpdir):
     """Test inf val data loader (e.g. IterableDataset)"""
```
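For the error path these tests exercise, a sketch of the case that still fails after the change: a fractional `val_check_interval` other than `1.0` with an infinite loader. The `model` fixture is a placeholder, and the exact exception import path varies by version, so the sketch catches a generic `Exception`:

```python
import pytest


def test_inf_loader_fractional_interval_sketch(tmpdir, model):  # hypothetical `model` fixture
    from pytorch_lightning import Trainer

    trainer = Trainer(
        default_save_path=tmpdir,  # 0.7-era argument, matching the tests above
        max_epochs=1,
        val_check_interval=0.5,    # fractional and != 1.0 -> still rejected
    )
    # The check runs when the train dataloader is inspected during fit():
    with pytest.raises(Exception):  # MisconfigurationException in Lightning
        trainer.fit(model)
```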