diff --git a/CHANGELOG.md b/CHANGELOG.md index 00fa8d1bf8..2e800c2896 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -33,8 +33,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added remaining `sklearn` metrics: `AveragePrecision`, `BalancedAccuracy`, `CohenKappaScore`, `DCG`, `Hamming`, `Hinge`, `Jaccard`, `MeanAbsoluteError`, `MeanSquaredError`, `MeanSquaredLogError`, `MedianAbsoluteError`, `R2Score`, `MeanPoissonDeviance`, `MeanGammaDeviance`, `MeanTweedieDeviance`, `ExplainedVariance` ([#2562](https://github.com/PyTorchLightning/pytorch-lightning/pull/2562)) -- Added support for `limit_{mode}_batches (int)` to work with infinite dataloader (IterableDataset) ([#2787](https://github.com/PyTorchLightning/pytorch-lightning/pull/2787)) - ### Changed - Truncated long version numbers in progress bar ([#2594](https://github.com/PyTorchLightning/pytorch-lightning/pull/2594)) diff --git a/docs/source/sequences.rst b/docs/source/sequences.rst index b9a8f2ee64..e24ee5bbca 100644 --- a/docs/source/sequences.rst +++ b/docs/source/sequences.rst @@ -49,8 +49,8 @@ Lightning can handle TBTT automatically via this flag. .. note:: If you need to modify how the batch is split, override :meth:`pytorch_lightning.core.LightningModule.tbptt_split_batch`. -.. note:: Using this feature requires updating your LightningModule's - :meth:`pytorch_lightning.core.LightningModule.training_step` to include a `hiddens` arg. +.. note:: Using this feature requires updating your LightningModule's :meth:`pytorch_lightning.core.LightningModule.training_step` to include + a `hiddens` arg. ---------- @@ -59,13 +59,10 @@ Iterable Datasets Lightning supports using IterableDatasets as well as map-style Datasets. IterableDatasets provide a more natural option when using sequential data. -.. note:: When using an IterableDataset you must set the ``val_check_interval`` to 1.0 (the default) or an int - (specifying the number of training batches to run before validation) when initializing the Trainer. This is - because the IterableDataset does not have a ``__len__`` and Lightning requires this to calculate the validation - interval when ``val_check_interval`` is less than one. Similarly, you can set ``limit_{mode}_batches`` to a float or - an int. If it is set to 0.0 or 0 it will set ``num_{mode}_batches`` to 0, if it is an int it will set ``num_{mode}_batches`` - to ``limit_{mode}_batches``, if it is set to 1.0 it will run for the whole dataset, otherwise it will throw an exception. - Here mode can be train/val/test. +.. note:: When using an IterableDataset you must set the val_check_interval to 1.0 (the default) or to an int + (specifying the number of training batches to run before validation) when initializing the Trainer. + This is due to the fact that the IterableDataset does not have a __len__ and Lightning requires this to calculate + the validation interval when val_check_interval is less than one. .. testcode:: @@ -90,9 +87,3 @@ option when using sequential data. # Set val_check_interval trainer = Trainer(val_check_interval=100) - - # Set limit_val_batches to 0.0 or 0 - trainer = Trainer(limit_val_batches=0.0) - - # Set limit_val_batches as an int - trainer = Trainer(limit_val_batches=100) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index c09d981d1d..80081c0dd4 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1771,7 +1771,7 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod elif self.example_input_array is not None: input_data = self.example_input_array else: - raise ValueError('`input_sample` and `example_input_array` tensors are both missing.') + raise ValueError(f'input_sample and example_input_array tensors are both missing.') if 'example_outputs' not in kwargs: self.eval() diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 4eec847580..09186765c6 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -212,19 +212,18 @@ class TrainerDataLoadingMixin(ABC): # automatically add samplers self.train_dataloader = self.auto_add_sampler(self.train_dataloader, train=True) - self.num_training_batches = len(self.train_dataloader) if _has_len(self.train_dataloader) else float('inf') self._worker_check(self.train_dataloader, 'train dataloader') self._check_batch_limits('limit_train_batches') - if isinstance(self.limit_train_batches, int) or self.limit_train_batches == 0.0: - self.num_training_batches = min(self.num_training_batches, int(self.limit_train_batches)) - elif self.num_training_batches != float('inf'): - self.num_training_batches = int(self.num_training_batches * self.limit_train_batches) - elif self.limit_train_batches != 1.0: - raise MisconfigurationException( - 'When using an IterableDataset for `limit_train_batches`,' - ' `Trainer(limit_train_batches)` must be `0.0`, `1.0` or an int. An int k specifies' - ' `num_training_batches` to use.') + if not _has_len(self.train_dataloader): + self.num_training_batches = float('inf') + else: + # try getting the length + if isinstance(self.limit_train_batches, float): + self.num_training_batches = len(self.train_dataloader) + self.num_training_batches = int(self.num_training_batches * self.limit_train_batches) + else: + self.num_training_batches = min(len(self.train_dataloader), self.limit_train_batches) # determine when to check validation # if int passed in, val checks that often @@ -242,7 +241,8 @@ class TrainerDataLoadingMixin(ABC): self.val_check_batch = float('inf') else: raise MisconfigurationException( - 'When using an IterableDataset for `train_dataloader`,' + 'When using an infinite DataLoader (e.g. with an IterableDataset' + ' or when DataLoader does not implement `__len__`) for `train_dataloader`,' ' `Trainer(val_check_interval)` must be `1.0` or an int. An int k specifies' ' checking validation every k training batches.') else: @@ -304,21 +304,24 @@ class TrainerDataLoadingMixin(ABC): for i, dataloader in enumerate(dataloaders): num_batches = len(dataloader) if _has_len(dataloader) else float('inf') self._worker_check(dataloader, f'{mode} dataloader {i}') - self._check_batch_limits(f'limit_{mode}_batches') # percent or num_steps limit_eval_batches = getattr(self, f'limit_{mode}_batches') - # limit num batches either as a percent or num steps - if isinstance(limit_eval_batches, int) or limit_eval_batches == 0.0: - num_batches = min(num_batches, int(limit_eval_batches)) - elif num_batches != float('inf'): - num_batches = int(num_batches * limit_eval_batches) - elif limit_eval_batches != 1.0: + if num_batches != float('inf'): + self._check_batch_limits(f'limit_{mode}_batches') + + # limit num batches either as a percent or num steps + if isinstance(limit_eval_batches, float): + num_batches = int(num_batches * limit_eval_batches) + else: + num_batches = min(len(dataloader), limit_eval_batches) + + elif limit_eval_batches not in (0.0, 1.0): raise MisconfigurationException( - 'When using an IterableDataset for `limit_{mode}_batches`,' - f' `Trainer(limit_{mode}_batches)` must be `0.0`, `1.0` or an int. An int k specifies' - f' `num_{mode}_batches` to use.') + 'When using an infinite DataLoader (e.g. with an IterableDataset' + f' or when DataLoader does not implement `__len__`) for `limit_{mode}_batches`,' + f' `Trainer(limit_{mode}_batches)` must be `0.0` or `1.0`.') if num_batches == 0 and limit_eval_batches > 0.0 and isinstance(limit_eval_batches, float): min_pct = 1.0 / len(dataloader) diff --git a/tests/models/test_onnx_save.py b/tests/models/test_onnx_save.py index 7cb40561f7..f824f33c93 100644 --- a/tests/models/test_onnx_save.py +++ b/tests/models/test_onnx_save.py @@ -84,7 +84,7 @@ def test_error_if_no_input(tmpdir): model = EvalModelTemplate() model.example_input_array = None file_path = os.path.join(tmpdir, "model.onxx") - with pytest.raises(ValueError, match=r'`input_sample` and `example_input_array` tensors are both missing'): + with pytest.raises(ValueError, match=r'input_sample and example_input_array tensors are both missing'): model.to_onnx(file_path) diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index 1aad504785..1c7e21b7a7 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -256,69 +256,6 @@ def test_multiple_dataloaders_passed_to_fit(tmpdir, ckpt_path): f'Multiple `test_dataloaders` not initiated properly, got {trainer.test_dataloaders}' -@pytest.mark.parametrize( - ['limit_train_batches', 'limit_val_batches', 'limit_test_batches'], - [ - pytest.param(0.0, 0.0, 0.0), - pytest.param(1.0, 1.0, 1.0), - ] -) -def test_inf_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches, - limit_val_batches, limit_test_batches): - """Verify inf train, val & test dataloaders (e.g. IterableDataset) passed with batch limit in percent""" - model = EvalModelTemplate() - model.train_dataloader = model.train_dataloader__infinite - model.val_dataloader = model.val_dataloader__infinite - model.test_dataloader = model.test_dataloader__infinite - - trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=1, - limit_train_batches=limit_train_batches, - limit_val_batches=limit_val_batches, - limit_test_batches=limit_test_batches, - ) - - results = trainer.fit(model) - assert results == 1 - assert trainer.num_training_batches == 0 if limit_train_batches == 0.0 else float('inf') - assert trainer.num_val_batches[0] == 0 if limit_val_batches == 0.0 else float('inf') - - trainer.test(ckpt_path=None) - assert trainer.num_test_batches[0] == 0 if limit_test_batches == 0.0 else float('inf') - - -@pytest.mark.parametrize( - ['limit_train_batches', 'limit_val_batches', 'limit_test_batches'], - [ - pytest.param(0, 0, 0), - pytest.param(10, 10, 10), - ] -) -def test_inf_dataloaders_with_limit_num_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches): - """Verify inf train, val & test dataloaders (e.g. IterableDataset) passed with batch limit as number""" - model = EvalModelTemplate() - model.train_dataloader = model.train_dataloader__infinite - model.val_dataloader = model.val_dataloader__infinite - model.test_dataloader = model.test_dataloader__infinite - - trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=1, - limit_train_batches=limit_train_batches, - limit_val_batches=limit_val_batches, - limit_test_batches=limit_test_batches, - ) - - results = trainer.fit(model) - assert results - assert trainer.num_training_batches == limit_train_batches - assert trainer.num_val_batches[0] == limit_val_batches - - trainer.test(ckpt_path=None) - assert trainer.num_test_batches[0] == limit_test_batches - - @pytest.mark.parametrize( ['limit_train_batches', 'limit_val_batches', 'limit_test_batches'], [ @@ -329,7 +266,7 @@ def test_inf_dataloaders_with_limit_num_batches(tmpdir, limit_train_batches, lim ] ) def test_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches): - """Verify num_batches for train, val & test dataloaders passed with batch limit in percent""" + """Verify num_batches for val & test dataloaders passed with batch limit in percent""" model = EvalModelTemplate() model.val_dataloader = model.val_dataloader__multiple_mixed_length model.test_dataloader = model.test_dataloader__multiple_mixed_length @@ -370,7 +307,7 @@ def test_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches, lim ] ) def test_dataloaders_with_limit_num_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches): - """Verify num_batches for train, val & test dataloaders passed with batch limit as number""" + """Verify num_batches for val & test dataloaders passed with batch limit as number""" os.environ['PL_DEV_DEBUG'] = '1' model = EvalModelTemplate() @@ -499,7 +436,7 @@ def test_train_inf_dataloader_error(tmpdir): trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, val_check_interval=0.5) - with pytest.raises(MisconfigurationException, match='using an IterableDataset'): + with pytest.raises(MisconfigurationException, match='infinite DataLoader'): trainer.fit(model) @@ -510,7 +447,7 @@ def test_val_inf_dataloader_error(tmpdir): trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_val_batches=0.5) - with pytest.raises(MisconfigurationException, match='using an IterableDataset'): + with pytest.raises(MisconfigurationException, match='infinite DataLoader'): trainer.fit(model) @@ -521,7 +458,7 @@ def test_test_inf_dataloader_error(tmpdir): trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_test_batches=0.5) - with pytest.raises(MisconfigurationException, match='using an IterableDataset'): + with pytest.raises(MisconfigurationException, match='infinite DataLoader'): trainer.test(model) @@ -837,7 +774,7 @@ def test_train_dataloader_not_implemented_error_failed(tmpdir): trainer = Trainer(default_root_dir=tmpdir, max_steps=5, max_epochs=1, val_check_interval=0.5) - with pytest.raises(MisconfigurationException, match='using an IterableDataset'): + with pytest.raises(MisconfigurationException, match='infinite DataLoader'): trainer.fit(model) @@ -848,7 +785,7 @@ def test_val_dataloader_not_implemented_error_failed(tmpdir): trainer = Trainer(default_root_dir=tmpdir, max_steps=5, max_epochs=1, limit_val_batches=0.5) - with pytest.raises(MisconfigurationException, match='using an IterableDataset'): + with pytest.raises(MisconfigurationException, match='infinite DataLoader'): trainer.fit(model) @@ -859,5 +796,5 @@ def test_test_dataloader_not_implemented_error_failed(tmpdir): trainer = Trainer(default_root_dir=tmpdir, max_steps=5, max_epochs=1, limit_test_batches=0.5) - with pytest.raises(MisconfigurationException, match='using an IterableDataset'): + with pytest.raises(MisconfigurationException, match='infinite DataLoader'): trainer.test(model)