diff --git a/CHANGELOG.md b/CHANGELOG.md
index d4a848bd4d..90c9f49566 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added support for non-primitive types in `hparams` for `TensorboardLogger` ([#1130](https://github.com/PyTorchLightning/pytorch-lightning/pull/1130))
 - Added a check that stops the training when loss or weights contain `NaN` or `inf` values. ([#1097](https://github.com/PyTorchLightning/pytorch-lightning/pull/1097))
 - Updated references to self.forward() to instead use the `__call__` interface. ([#1211](https://github.com/PyTorchLightning/pytorch-lightning/pull/1211))
+- Added support for `IterableDataset` when `val_check_interval=1.0` (default), this will trigger validation at the end of each epoch. ([#1283](https://github.com/PyTorchLightning/pytorch-lightning/pull/1283))
 
 ### Changed
 
diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py
index e848e09725..dced28144f 100644
--- a/pytorch_lightning/trainer/data_loading.py
+++ b/pytorch_lightning/trainer/data_loading.py
@@ -136,16 +136,19 @@ class TrainerDataLoadingMixin(ABC):
                     'If you want to disable validation set `val_percent_check` to 0.0 instead.')
         else:
             if not _has_len(self.train_dataloader):
-                raise MisconfigurationException(
-                    'When using an infinite DataLoader (e.g. with an IterableDataset or when '
-                    'DataLoader does not implement `__len__`) for `train_dataloader`, '
-                    '`Trainer(val_check_interval)` must be an int. An int k specifies checking '
-                    'validation every k training batches.')
+                if self.val_check_interval == 1.0:
+                    self.val_check_batch = float('inf')
+                else:
+                    raise MisconfigurationException(
+                        'When using an infinite DataLoader (e.g. with an IterableDataset or when '
+                        'DataLoader does not implement `__len__`) for `train_dataloader`, '
+                        '`Trainer(val_check_interval)` must be `1.0` or an int. An int k specifies '
+                        'checking validation every k training batches.')
+            else:
+                self._percent_range_check('val_check_interval')
 
-            self._percent_range_check('val_check_interval')
-
-            self.val_check_batch = int(self.num_training_batches * self.val_check_interval)
-            self.val_check_batch = max(1, self.val_check_batch)
+                self.val_check_batch = int(self.num_training_batches * self.val_check_interval)
+                self.val_check_batch = max(1, self.val_check_batch)
 
     def _reset_eval_dataloader(self, model: LightningModule,
                                mode: str) -> Tuple[int, List[DataLoader]]:
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index a90a43b0d8..8cfcd70433 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -400,8 +400,8 @@ class TrainerTrainLoopMixin(ABC):
             train_dataloader = train_dataloader.per_device_loader(device)
 
         # run epoch
-        for batch_idx, batch in self.profiler.profile_iterable(
-            enumerate(train_dataloader), "get_train_batch"
+        for batch_idx, (batch, is_last_batch) in self.profiler.profile_iterable(
+            enumerate(_with_is_last(train_dataloader)), "get_train_batch"
         ):
             # stop epoch if we limited the number of training batches
             if batch_idx >= self.num_training_batches:
@@ -429,8 +429,10 @@ class TrainerTrainLoopMixin(ABC):
             # ---------------
             is_val_check_batch = (batch_idx + 1) % self.val_check_batch == 0
             can_check_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0
-            should_check_val = not self.disable_validation and can_check_epoch
-            should_check_val = should_check_val and (is_val_check_batch or early_stop_epoch)
+            can_check_val = not self.disable_validation and can_check_epoch
+            should_check_val = is_val_check_batch or early_stop_epoch
+            should_check_val = should_check_val or (is_last_batch and self.val_check_batch == float('inf'))
+            should_check_val = can_check_val and should_check_val
 
             # fast_dev_run always forces val checking after train batch
             if self.fast_dev_run or should_check_val:
@@ -740,3 +742,16 @@ class TrainerTrainLoopMixin(ABC):
         if self.checkpoint_callback is not None:
             self.checkpoint_callback.on_validation_end(self, self.get_model())
         self.on_validation_end()
+
+
+def _with_is_last(iterable):
+    """Pass through values from the given iterable with an added boolean indicating if this is the last item.
+    See `https://stackoverflow.com/a/1630350 <https://stackoverflow.com/a/1630350>`_"""
+    it = iter(iterable)
+    last = next(it)
+    for val in it:
+        # yield last and has next
+        yield last, False
+        last = val
+    # yield last, no longer has next
+    yield last, True
diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py
index 6f0ee15aef..194d67fe07 100644
--- a/tests/trainer/test_dataloaders.py
+++ b/tests/trainer/test_dataloaders.py
@@ -372,7 +372,6 @@ def test_inf_train_dataloader(tmpdir):
     )
     trainer.fit(model)
 
-    # logger file to get meta
     trainer = Trainer(
         default_save_path=tmpdir,
         max_epochs=1,
@@ -383,6 +382,15 @@
     # verify training completed
     assert result == 1
 
+    trainer = Trainer(
+        default_save_path=tmpdir,
+        max_epochs=1
+    )
+    result = trainer.fit(model)
+
+    # verify training completed
+    assert result == 1
+
 
 def test_inf_val_dataloader(tmpdir):
     """Test inf val data loader (e.g. IterableDataset)"""
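
Reviewer note: below is a minimal standalone sketch (not part of the patch) of how the two new pieces fit together. `_with_is_last` is the helper added in training_loop.py; `fake_batches` is a hypothetical stand-in for a DataLoader over an `IterableDataset`, and `val_check_batch = float('inf')` mimics what `reset_train_dataloader` now sets when `val_check_interval == 1.0`.

    def _with_is_last(iterable):
        it = iter(iterable)
        last = next(it)
        for val in it:
            yield last, False  # more items follow `last`
            last = val
        yield last, True       # `last` is the final item


    def fake_batches():
        # hypothetical stand-in for an unsized DataLoader (no __len__)
        yield from ("batch0", "batch1", "batch2")


    val_check_batch = float('inf')  # what the patch sets when val_check_interval == 1.0

    for batch_idx, (batch, is_last_batch) in enumerate(_with_is_last(fake_batches())):
        # mirrors the new training-loop condition: validate on the last batch
        should_check_val = is_last_batch and val_check_batch == float('inf')
        print(batch_idx, batch, is_last_batch, should_check_val)

    # prints:
    # 0 batch0 False False
    # 1 batch1 False False
    # 2 batch2 True True

Because the last-item flag is computed lazily, the loader never needs `__len__`, and validation still fires exactly once per epoch, when the underlying iterator is exhausted.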