Revert "Support limit_mode_batches (int) for infinite dataloader" (#2839)

* Revert "Support limit_mode_batches (int) for infinite dataloader (#2787)"

This reverts commit de9c9f0864.

* Update training_tricks.py
commit 5d0f0325d8
parent 2242af11b6
Author: William Falcon
Date:   2020-08-05 15:57:26 -04:00 (committed by GitHub)
6 changed files with 40 additions and 111 deletions


@@ -33,8 +33,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added remaining `sklearn` metrics: `AveragePrecision`, `BalancedAccuracy`, `CohenKappaScore`, `DCG`, `Hamming`, `Hinge`, `Jaccard`, `MeanAbsoluteError`, `MeanSquaredError`, `MeanSquaredLogError`, `MedianAbsoluteError`, `R2Score`, `MeanPoissonDeviance`, `MeanGammaDeviance`, `MeanTweedieDeviance`, `ExplainedVariance` ([#2562](https://github.com/PyTorchLightning/pytorch-lightning/pull/2562))
-- Added support for `limit_{mode}_batches (int)` to work with infinite dataloader (IterableDataset) ([#2787](https://github.com/PyTorchLightning/pytorch-lightning/pull/2787))

 ### Changed

 - Truncated long version numbers in progress bar ([#2594](https://github.com/PyTorchLightning/pytorch-lightning/pull/2594))


@@ -49,8 +49,8 @@ Lightning can handle TBTT automatically via this flag.
 .. note:: If you need to modify how the batch is split,
    override :meth:`pytorch_lightning.core.LightningModule.tbptt_split_batch`.

-.. note:: Using this feature requires updating your LightningModule's
-   :meth:`pytorch_lightning.core.LightningModule.training_step` to include a `hiddens` arg.
+.. note:: Using this feature requires updating your LightningModule's :meth:`pytorch_lightning.core.LightningModule.training_step` to include
+   a `hiddens` arg.

 ----------
@@ -59,13 +59,10 @@ Iterable Datasets
 Lightning supports using IterableDatasets as well as map-style Datasets. IterableDatasets provide a more natural
 option when using sequential data.

-.. note:: When using an IterableDataset you must set the ``val_check_interval`` to 1.0 (the default) or an int
-   (specifying the number of training batches to run before validation) when initializing the Trainer. This is
-   because the IterableDataset does not have a ``__len__`` and Lightning requires this to calculate the validation
-   interval when ``val_check_interval`` is less than one. Similarly, you can set ``limit_{mode}_batches`` to a float or
-   an int. If it is set to 0.0 or 0 it will set ``num_{mode}_batches`` to 0, if it is an int it will set ``num_{mode}_batches``
-   to ``limit_{mode}_batches``, if it is set to 1.0 it will run for the whole dataset, otherwise it will throw an exception.
-   Here mode can be train/val/test.
+.. note:: When using an IterableDataset you must set the val_check_interval to 1.0 (the default) or to an int
+   (specifying the number of training batches to run before validation) when initializing the Trainer.
+   This is due to the fact that the IterableDataset does not have a __len__ and Lightning requires this to calculate
+   the validation interval when val_check_interval is less than one.

 .. testcode::
@@ -90,9 +87,3 @@ option when using sequential data.
     # Set val_check_interval
     trainer = Trainer(val_check_interval=100)
-
-    # Set limit_val_batches to 0.0 or 0
-    trainer = Trainer(limit_val_batches=0.0)
-
-    # Set limit_val_batches as an int
-    trainer = Trainer(limit_val_batches=100)
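For context, here is a minimal sketch (not part of the commit) of the batch-limit semantics this revert restores. The helper name and its arguments are hypothetical, written only to illustrate the rule: a float limit scales a finite dataloader's length, an int limit caps it, and an infinite dataloader (one with no __len__) once again accepts only the no-op values 0.0 or 1.0.

    # Hypothetical helper illustrating the restored semantics; not Lightning API.
    def resolve_num_batches(dataloader_len, limit):
        """dataloader_len is None for an infinite DataLoader (no __len__)."""
        if dataloader_len is None:
            if limit not in (0.0, 1.0):
                # mirrors the MisconfigurationException raised in the trainer code below
                raise ValueError('limit batches must be 0.0 or 1.0 for an infinite DataLoader')
            return float('inf')
        if isinstance(limit, float):
            return int(dataloader_len * limit)      # fraction of the dataset
        return min(dataloader_len, limit)           # explicit number of batches

    assert resolve_num_batches(100, 0.25) == 25    # 25% of 100 batches
    assert resolve_num_batches(100, 40) == 40      # int caps still work for finite loaders
    assert resolve_num_batches(None, 1.0) == float('inf')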


@@ -1771,7 +1771,7 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
         elif self.example_input_array is not None:
             input_data = self.example_input_array
         else:
-            raise ValueError('`input_sample` and `example_input_array` tensors are both missing.')
+            raise ValueError(f'input_sample and example_input_array tensors are both missing.')

         if 'example_outputs' not in kwargs:
             self.eval()
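The hunk above only reverts the wording of the error message. For reference, a short hedged example of the to_onnx fallback order this code path implements: an explicit input_sample wins, otherwise example_input_array is traced, otherwise the ValueError above is raised. TinyModel is a made-up minimal module, not code from the repository.

    import torch
    from torch import nn
    import pytorch_lightning as pl

    class TinyModel(pl.LightningModule):          # made-up minimal module
        def __init__(self):
            super().__init__()
            self.layer = nn.Linear(28 * 28, 10)
            # used by to_onnx when no input_sample is passed
            self.example_input_array = torch.randn(1, 28 * 28)

        def forward(self, x):
            return self.layer(x)

    model = TinyModel()
    model.to_onnx('model.onnx')                   # traces with example_input_array
    model.example_input_array = None
    # with neither input_sample nor example_input_array, to_onnx raises the
    # ValueError shown in the hunk above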


@@ -212,19 +212,18 @@ class TrainerDataLoadingMixin(ABC):

         # automatically add samplers
         self.train_dataloader = self.auto_add_sampler(self.train_dataloader, train=True)

-        self.num_training_batches = len(self.train_dataloader) if _has_len(self.train_dataloader) else float('inf')
         self._worker_check(self.train_dataloader, 'train dataloader')
         self._check_batch_limits('limit_train_batches')

-        if isinstance(self.limit_train_batches, int) or self.limit_train_batches == 0.0:
-            self.num_training_batches = min(self.num_training_batches, int(self.limit_train_batches))
-        elif self.num_training_batches != float('inf'):
-            self.num_training_batches = int(self.num_training_batches * self.limit_train_batches)
-        elif self.limit_train_batches != 1.0:
-            raise MisconfigurationException(
-                'When using an IterableDataset for `limit_train_batches`,'
-                ' `Trainer(limit_train_batches)` must be `0.0`, `1.0` or an int. An int k specifies'
-                ' `num_training_batches` to use.')
+        if not _has_len(self.train_dataloader):
+            self.num_training_batches = float('inf')
+        else:
+            # try getting the length
+            if isinstance(self.limit_train_batches, float):
+                self.num_training_batches = len(self.train_dataloader)
+                self.num_training_batches = int(self.num_training_batches * self.limit_train_batches)
+            else:
+                self.num_training_batches = min(len(self.train_dataloader), self.limit_train_batches)

         # determine when to check validation
         # if int passed in, val checks that often
@@ -242,7 +241,8 @@ class TrainerDataLoadingMixin(ABC):
                 self.val_check_batch = float('inf')
             else:
                 raise MisconfigurationException(
-                    'When using an IterableDataset for `train_dataloader`,'
+                    'When using an infinite DataLoader (e.g. with an IterableDataset'
+                    ' or when DataLoader does not implement `__len__`) for `train_dataloader`,'
                     ' `Trainer(val_check_interval)` must be `1.0` or an int. An int k specifies'
                     ' checking validation every k training batches.')
         else:
@@ -304,21 +304,24 @@ class TrainerDataLoadingMixin(ABC):
         for i, dataloader in enumerate(dataloaders):
             num_batches = len(dataloader) if _has_len(dataloader) else float('inf')
             self._worker_check(dataloader, f'{mode} dataloader {i}')
-            self._check_batch_limits(f'limit_{mode}_batches')

             # percent or num_steps
             limit_eval_batches = getattr(self, f'limit_{mode}_batches')

-            # limit num batches either as a percent or num steps
-            if isinstance(limit_eval_batches, int) or limit_eval_batches == 0.0:
-                num_batches = min(num_batches, int(limit_eval_batches))
-            elif num_batches != float('inf'):
-                num_batches = int(num_batches * limit_eval_batches)
-            elif limit_eval_batches != 1.0:
+            if num_batches != float('inf'):
+                self._check_batch_limits(f'limit_{mode}_batches')
+
+                # limit num batches either as a percent or num steps
+                if isinstance(limit_eval_batches, float):
+                    num_batches = int(num_batches * limit_eval_batches)
+                else:
+                    num_batches = min(len(dataloader), limit_eval_batches)
+            elif limit_eval_batches not in (0.0, 1.0):
                 raise MisconfigurationException(
-                    'When using an IterableDataset for `limit_{mode}_batches`,'
-                    f' `Trainer(limit_{mode}_batches)` must be `0.0`, `1.0` or an int. An int k specifies'
-                    f' `num_{mode}_batches` to use.')
+                    'When using an infinite DataLoader (e.g. with an IterableDataset'
+                    f' or when DataLoader does not implement `__len__`) for `limit_{mode}_batches`,'
+                    f' `Trainer(limit_{mode}_batches)` must be `0.0` or `1.0`.')

             if num_batches == 0 and limit_eval_batches > 0.0 and isinstance(limit_eval_batches, float):
                 min_pct = 1.0 / len(dataloader)
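A small sketch (again hypothetical, not Lightning's actual function layout) of the val_check_interval rule the restored error message enforces: an int means "validate every k training batches", a float means a fraction of the epoch, and a fraction is impossible when the epoch length is unknown.

    # Hypothetical helper mirroring the restored logic; not an actual API.
    def resolve_val_check_batch(val_check_interval, num_training_batches):
        if isinstance(val_check_interval, int):
            return val_check_interval                 # validate every k training batches
        if num_training_batches == float('inf'):
            if val_check_interval == 1.0:
                return float('inf')                   # only validate at epoch boundaries
            raise ValueError('val_check_interval must be 1.0 or an int'
                             ' when the DataLoader is infinite')
        return int(num_training_batches * val_check_interval)

    assert resolve_val_check_batch(100, float('inf')) == 100
    assert resolve_val_check_batch(0.25, 800) == 200  # validate 4 times per epoch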


@@ -84,7 +84,7 @@ def test_error_if_no_input(tmpdir):
     model = EvalModelTemplate()
     model.example_input_array = None
     file_path = os.path.join(tmpdir, "model.onxx")
-    with pytest.raises(ValueError, match=r'`input_sample` and `example_input_array` tensors are both missing'):
+    with pytest.raises(ValueError, match=r'input_sample and example_input_array tensors are both missing'):
         model.to_onnx(file_path)


@@ -256,69 +256,6 @@ def test_multiple_dataloaders_passed_to_fit(tmpdir, ckpt_path):
         f'Multiple `test_dataloaders` not initiated properly, got {trainer.test_dataloaders}'

-@pytest.mark.parametrize(
-    ['limit_train_batches', 'limit_val_batches', 'limit_test_batches'],
-    [
-        pytest.param(0.0, 0.0, 0.0),
-        pytest.param(1.0, 1.0, 1.0),
-    ]
-)
-def test_inf_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches,
-                                                    limit_val_batches, limit_test_batches):
-    """Verify inf train, val & test dataloaders (e.g. IterableDataset) passed with batch limit in percent"""
-    model = EvalModelTemplate()
-    model.train_dataloader = model.train_dataloader__infinite
-    model.val_dataloader = model.val_dataloader__infinite
-    model.test_dataloader = model.test_dataloader__infinite
-
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        max_epochs=1,
-        limit_train_batches=limit_train_batches,
-        limit_val_batches=limit_val_batches,
-        limit_test_batches=limit_test_batches,
-    )
-
-    results = trainer.fit(model)
-    assert results == 1
-    assert trainer.num_training_batches == 0 if limit_train_batches == 0.0 else float('inf')
-    assert trainer.num_val_batches[0] == 0 if limit_val_batches == 0.0 else float('inf')
-
-    trainer.test(ckpt_path=None)
-    assert trainer.num_test_batches[0] == 0 if limit_test_batches == 0.0 else float('inf')
-
-
-@pytest.mark.parametrize(
-    ['limit_train_batches', 'limit_val_batches', 'limit_test_batches'],
-    [
-        pytest.param(0, 0, 0),
-        pytest.param(10, 10, 10),
-    ]
-)
-def test_inf_dataloaders_with_limit_num_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches):
-    """Verify inf train, val & test dataloaders (e.g. IterableDataset) passed with batch limit as number"""
-    model = EvalModelTemplate()
-    model.train_dataloader = model.train_dataloader__infinite
-    model.val_dataloader = model.val_dataloader__infinite
-    model.test_dataloader = model.test_dataloader__infinite
-
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        max_epochs=1,
-        limit_train_batches=limit_train_batches,
-        limit_val_batches=limit_val_batches,
-        limit_test_batches=limit_test_batches,
-    )
-
-    results = trainer.fit(model)
-    assert results
-    assert trainer.num_training_batches == limit_train_batches
-    assert trainer.num_val_batches[0] == limit_val_batches
-
-    trainer.test(ckpt_path=None)
-    assert trainer.num_test_batches[0] == limit_test_batches
-
-
 @pytest.mark.parametrize(
     ['limit_train_batches', 'limit_val_batches', 'limit_test_batches'],
     [
@@ -329,7 +266,7 @@ def test_inf_dataloaders_with_limit_num_batches(tmpdir, limit_train_batches, lim
     ]
 )
 def test_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches):
-    """Verify num_batches for train, val & test dataloaders passed with batch limit in percent"""
+    """Verify num_batches for val & test dataloaders passed with batch limit in percent"""
     model = EvalModelTemplate()
     model.val_dataloader = model.val_dataloader__multiple_mixed_length
     model.test_dataloader = model.test_dataloader__multiple_mixed_length

@@ -370,7 +307,7 @@ def test_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches, lim
     ]
 )
 def test_dataloaders_with_limit_num_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches):
-    """Verify num_batches for train, val & test dataloaders passed with batch limit as number"""
+    """Verify num_batches for val & test dataloaders passed with batch limit as number"""
     os.environ['PL_DEV_DEBUG'] = '1'

     model = EvalModelTemplate()
@@ -499,7 +436,7 @@ def test_train_inf_dataloader_error(tmpdir):
     trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, val_check_interval=0.5)

-    with pytest.raises(MisconfigurationException, match='using an IterableDataset'):
+    with pytest.raises(MisconfigurationException, match='infinite DataLoader'):
         trainer.fit(model)

@@ -510,7 +447,7 @@ def test_val_inf_dataloader_error(tmpdir):
     trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_val_batches=0.5)

-    with pytest.raises(MisconfigurationException, match='using an IterableDataset'):
+    with pytest.raises(MisconfigurationException, match='infinite DataLoader'):
         trainer.fit(model)

@@ -521,7 +458,7 @@ def test_test_inf_dataloader_error(tmpdir):
     trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_test_batches=0.5)

-    with pytest.raises(MisconfigurationException, match='using an IterableDataset'):
+    with pytest.raises(MisconfigurationException, match='infinite DataLoader'):
         trainer.test(model)

@@ -837,7 +774,7 @@ def test_train_dataloader_not_implemented_error_failed(tmpdir):
     trainer = Trainer(default_root_dir=tmpdir, max_steps=5, max_epochs=1, val_check_interval=0.5)

-    with pytest.raises(MisconfigurationException, match='using an IterableDataset'):
+    with pytest.raises(MisconfigurationException, match='infinite DataLoader'):
         trainer.fit(model)

@@ -848,7 +785,7 @@ def test_val_dataloader_not_implemented_error_failed(tmpdir):
     trainer = Trainer(default_root_dir=tmpdir, max_steps=5, max_epochs=1, limit_val_batches=0.5)

-    with pytest.raises(MisconfigurationException, match='using an IterableDataset'):
+    with pytest.raises(MisconfigurationException, match='infinite DataLoader'):
         trainer.fit(model)

@@ -859,5 +796,5 @@ def test_test_dataloader_not_implemented_error_failed(tmpdir):
     trainer = Trainer(default_root_dir=tmpdir, max_steps=5, max_epochs=1, limit_test_batches=0.5)

-    with pytest.raises(MisconfigurationException, match='using an IterableDataset'):
+    with pytest.raises(MisconfigurationException, match='infinite DataLoader'):
         trainer.test(model)
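Finally, for readers wondering what the tests' train_dataloader__infinite style loaders look like: a minimal, self-contained approximation (the class below is illustrative, not the EvalModelTemplate implementation). len() failing with TypeError is exactly what Lightning's _has_len() check detects before treating a dataloader as infinite.

    import itertools
    import torch
    from torch.utils.data import DataLoader, IterableDataset

    class InfiniteDataset(IterableDataset):       # illustrative stand-in
        def __iter__(self):
            # stream random samples forever; deliberately no __len__
            return (torch.randn(32) for _ in itertools.count())

    loader = DataLoader(InfiniteDataset(), batch_size=None)
    try:
        len(loader)
    except TypeError:
        print('no __len__: Lightning treats this DataLoader as infinite')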