Revert "Support limit_mode_batches (int) for infinite dataloader" (#2839)
* Revert "Support limit_mode_batches (int) for infinite dataloader (#2787)"
This reverts commit de9c9f0864.
* Update training_tricks.py
parent 2242af11b6
commit 5d0f0325d8
@@ -33,8 +33,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added remaining `sklearn` metrics: `AveragePrecision`, `BalancedAccuracy`, `CohenKappaScore`, `DCG`, `Hamming`, `Hinge`, `Jaccard`, `MeanAbsoluteError`, `MeanSquaredError`, `MeanSquaredLogError`, `MedianAbsoluteError`, `R2Score`, `MeanPoissonDeviance`, `MeanGammaDeviance`, `MeanTweedieDeviance`, `ExplainedVariance` ([#2562](https://github.com/PyTorchLightning/pytorch-lightning/pull/2562))
 
-- Added support for `limit_{mode}_batches (int)` to work with infinite dataloader (IterableDataset) ([#2787](https://github.com/PyTorchLightning/pytorch-lightning/pull/2787))
-
 ### Changed
 
 - Truncated long version numbers in progress bar ([#2594](https://github.com/PyTorchLightning/pytorch-lightning/pull/2594))
@@ -49,8 +49,8 @@ Lightning can handle TBTT automatically via this flag.
 
 .. note:: If you need to modify how the batch is split,
    override :meth:`pytorch_lightning.core.LightningModule.tbptt_split_batch`.
 
-.. note:: Using this feature requires updating your LightningModule's
-   :meth:`pytorch_lightning.core.LightningModule.training_step` to include a `hiddens` arg.
+.. note:: Using this feature requires updating your LightningModule's :meth:`pytorch_lightning.core.LightningModule.training_step` to include
+   a `hiddens` arg.
 
----------
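
A quick illustration of the note being restored above: a minimal sketch of a `training_step` that accepts the `hiddens` arg. The `self.lstm` and `self.loss` attributes are hypothetical, not part of this diff.

def training_step(self, batch, batch_idx, hiddens):
    # `hiddens` carries the recurrent state across truncated splits
    x, y = batch
    out, hiddens = self.lstm(x, hiddens)  # hypothetical recurrent layer
    loss = self.loss(out, y)              # hypothetical loss callable
    # returning `hiddens` lets Lightning pass it into the next split
    return {'loss': loss, 'hiddens': hiddens}
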
@@ -59,13 +59,10 @@ Iterable Datasets
 
 Lightning supports using IterableDatasets as well as map-style Datasets. IterableDatasets provide a more natural
 option when using sequential data.
 
-.. note:: When using an IterableDataset you must set the ``val_check_interval`` to 1.0 (the default) or an int
-   (specifying the number of training batches to run before validation) when initializing the Trainer. This is
-   because the IterableDataset does not have a ``__len__`` and Lightning requires this to calculate the validation
-   interval when ``val_check_interval`` is less than one. Similarly, you can set ``limit_{mode}_batches`` to a float or
-   an int. If it is set to 0.0 or 0 it will set ``num_{mode}_batches`` to 0, if it is an int it will set ``num_{mode}_batches``
-   to ``limit_{mode}_batches``, if it is set to 1.0 it will run for the whole dataset, otherwise it will throw an exception.
-   Here mode can be train/val/test.
+.. note:: When using an IterableDataset you must set the val_check_interval to 1.0 (the default) or to an int
+   (specifying the number of training batches to run before validation) when initializing the Trainer.
+   This is due to the fact that the IterableDataset does not have a __len__ and Lightning requires this to calculate
+   the validation interval when val_check_interval is less than one.
 
 .. testcode::
@@ -90,9 +87,3 @@ option when using sequential data.
 
     # Set val_check_interval
    trainer = Trainer(val_check_interval=100)
-
-    # Set limit_val_batches to 0.0 or 0
-    trainer = Trainer(limit_val_batches=0.0)
-
-    # Set limit_val_batches as an int
-    trainer = Trainer(limit_val_batches=100)
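
To make the docs change above concrete, here is a minimal sketch of an infinite dataloader: the dataset defines `__iter__` but no `__len__`, which is exactly why a fractional `val_check_interval` cannot be computed.

import torch
from torch.utils.data import DataLoader, IterableDataset

class InfiniteDataset(IterableDataset):
    """Yields samples forever; deliberately defines no __len__."""

    def __iter__(self):
        while True:
            yield torch.randn(32)

train_loader = DataLoader(InfiniteDataset(), batch_size=16)
# len(train_loader) raises TypeError, so Lightning cannot translate a
# fractional val_check_interval into a number of batches.
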
@@ -1771,7 +1771,7 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
         elif self.example_input_array is not None:
             input_data = self.example_input_array
         else:
-            raise ValueError('`input_sample` and `example_input_array` tensors are both missing.')
+            raise ValueError(f'input_sample and example_input_array tensors are both missing.')
 
         if 'example_outputs' not in kwargs:
             self.eval()
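
For context on the error path above, a hedged usage sketch (the model definition is hypothetical): export succeeds when either `input_sample` is passed or `example_input_array` is set.

import torch
from pytorch_lightning import LightningModule

class LitModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        return self.layer(x)

model = LitModel()
# either of these avoids the ValueError raised in the `else` branch:
model.example_input_array = torch.randn(1, 28 * 28)
model.to_onnx('model.onnx')
# or: model.to_onnx('model.onnx', input_sample=torch.randn(1, 28 * 28))
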
@@ -212,19 +212,18 @@ class TrainerDataLoadingMixin(ABC):
         # automatically add samplers
         self.train_dataloader = self.auto_add_sampler(self.train_dataloader, train=True)
 
-        self.num_training_batches = len(self.train_dataloader) if _has_len(self.train_dataloader) else float('inf')
         self._worker_check(self.train_dataloader, 'train dataloader')
         self._check_batch_limits('limit_train_batches')
 
-        if isinstance(self.limit_train_batches, int) or self.limit_train_batches == 0.0:
-            self.num_training_batches = min(self.num_training_batches, int(self.limit_train_batches))
-        elif self.num_training_batches != float('inf'):
-            self.num_training_batches = int(self.num_training_batches * self.limit_train_batches)
-        elif self.limit_train_batches != 1.0:
-            raise MisconfigurationException(
-                'When using an IterableDataset for `limit_train_batches`,'
-                ' `Trainer(limit_train_batches)` must be `0.0`, `1.0` or an int. An int k specifies'
-                ' `num_training_batches` to use.')
+        if not _has_len(self.train_dataloader):
+            self.num_training_batches = float('inf')
+        else:
+            # try getting the length
+            if isinstance(self.limit_train_batches, float):
+                self.num_training_batches = len(self.train_dataloader)
+                self.num_training_batches = int(self.num_training_batches * self.limit_train_batches)
+            else:
+                self.num_training_batches = min(len(self.train_dataloader), self.limit_train_batches)
 
         # determine when to check validation
         # if int passed in, val checks that often
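
To summarize the restored branch above as a standalone sketch (helper names are hypothetical): a float limit scales a known length, an int limit caps it, and a loader without `__len__` is treated as infinite.

def has_len(dataloader) -> bool:
    # mirrors the spirit of `_has_len`: a DataLoader over an IterableDataset
    # raises TypeError when asked for its length
    try:
        return len(dataloader) is not None
    except TypeError:
        return False

def resolve_num_training_batches(dataloader, limit):
    if not has_len(dataloader):
        return float('inf')                   # infinite loader: bounded only by max_steps/max_epochs
    if isinstance(limit, float):
        return int(len(dataloader) * limit)   # e.g. 0.25 -> a quarter of the batches
    return min(len(dataloader), limit)        # an int caps the batch count
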
@@ -242,7 +241,8 @@ class TrainerDataLoadingMixin(ABC):
                 self.val_check_batch = float('inf')
             else:
                 raise MisconfigurationException(
-                    'When using an IterableDataset for `train_dataloader`,'
+                    'When using an infinite DataLoader (e.g. with an IterableDataset'
+                    ' or when DataLoader does not implement `__len__`) for `train_dataloader`,'
                     ' `Trainer(val_check_interval)` must be `1.0` or an int. An int k specifies'
                     ' checking validation every k training batches.')
         else:
@@ -304,21 +304,24 @@ class TrainerDataLoadingMixin(ABC):
         for i, dataloader in enumerate(dataloaders):
             num_batches = len(dataloader) if _has_len(dataloader) else float('inf')
             self._worker_check(dataloader, f'{mode} dataloader {i}')
-            self._check_batch_limits(f'limit_{mode}_batches')
 
             # percent or num_steps
             limit_eval_batches = getattr(self, f'limit_{mode}_batches')
 
-            # limit num batches either as a percent or num steps
-            if isinstance(limit_eval_batches, int) or limit_eval_batches == 0.0:
-                num_batches = min(num_batches, int(limit_eval_batches))
-            elif num_batches != float('inf'):
-                num_batches = int(num_batches * limit_eval_batches)
-            elif limit_eval_batches != 1.0:
+            if num_batches != float('inf'):
+                self._check_batch_limits(f'limit_{mode}_batches')
+
+                # limit num batches either as a percent or num steps
+                if isinstance(limit_eval_batches, float):
+                    num_batches = int(num_batches * limit_eval_batches)
+                else:
+                    num_batches = min(len(dataloader), limit_eval_batches)
+
+            elif limit_eval_batches not in (0.0, 1.0):
                 raise MisconfigurationException(
-                    'When using an IterableDataset for `limit_{mode}_batches`,'
-                    f' `Trainer(limit_{mode}_batches)` must be `0.0`, `1.0` or an int. An int k specifies'
-                    f' `num_{mode}_batches` to use.')
+                    'When using an infinite DataLoader (e.g. with an IterableDataset'
+                    f' or when DataLoader does not implement `__len__`) for `limit_{mode}_batches`,'
+                    f' `Trainer(limit_{mode}_batches)` must be `0.0` or `1.0`.')
 
             if num_batches == 0 and limit_eval_batches > 0.0 and isinstance(limit_eval_batches, float):
                 min_pct = 1.0 / len(dataloader)
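
A short usage sketch of the restored guard above, written in the style of the tests further down (the `tests.base` import path and infinite-loader helpers come from those tests):

import pytest
from pytorch_lightning import Trainer
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import EvalModelTemplate

model = EvalModelTemplate()
model.val_dataloader = model.val_dataloader__infinite

trainer = Trainer(max_epochs=1, limit_val_batches=0.5)
# a fractional limit cannot be applied to a loader with no __len__
with pytest.raises(MisconfigurationException, match='infinite DataLoader'):
    trainer.fit(model)
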
@@ -84,7 +84,7 @@ def test_error_if_no_input(tmpdir):
     model = EvalModelTemplate()
     model.example_input_array = None
     file_path = os.path.join(tmpdir, "model.onxx")
-    with pytest.raises(ValueError, match=r'`input_sample` and `example_input_array` tensors are both missing'):
+    with pytest.raises(ValueError, match=r'input_sample and example_input_array tensors are both missing'):
         model.to_onnx(file_path)
 
 
@@ -256,69 +256,6 @@ def test_multiple_dataloaders_passed_to_fit(tmpdir, ckpt_path):
         f'Multiple `test_dataloaders` not initiated properly, got {trainer.test_dataloaders}'
 
 
-@pytest.mark.parametrize(
-    ['limit_train_batches', 'limit_val_batches', 'limit_test_batches'],
-    [
-        pytest.param(0.0, 0.0, 0.0),
-        pytest.param(1.0, 1.0, 1.0),
-    ]
-)
-def test_inf_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches,
-                                                    limit_val_batches, limit_test_batches):
-    """Verify inf train, val & test dataloaders (e.g. IterableDataset) passed with batch limit in percent"""
-    model = EvalModelTemplate()
-    model.train_dataloader = model.train_dataloader__infinite
-    model.val_dataloader = model.val_dataloader__infinite
-    model.test_dataloader = model.test_dataloader__infinite
-
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        max_epochs=1,
-        limit_train_batches=limit_train_batches,
-        limit_val_batches=limit_val_batches,
-        limit_test_batches=limit_test_batches,
-    )
-
-    results = trainer.fit(model)
-    assert results == 1
-    assert trainer.num_training_batches == 0 if limit_train_batches == 0.0 else float('inf')
-    assert trainer.num_val_batches[0] == 0 if limit_val_batches == 0.0 else float('inf')
-
-    trainer.test(ckpt_path=None)
-    assert trainer.num_test_batches[0] == 0 if limit_test_batches == 0.0 else float('inf')
-
-
-@pytest.mark.parametrize(
-    ['limit_train_batches', 'limit_val_batches', 'limit_test_batches'],
-    [
-        pytest.param(0, 0, 0),
-        pytest.param(10, 10, 10),
-    ]
-)
-def test_inf_dataloaders_with_limit_num_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches):
-    """Verify inf train, val & test dataloaders (e.g. IterableDataset) passed with batch limit as number"""
-    model = EvalModelTemplate()
-    model.train_dataloader = model.train_dataloader__infinite
-    model.val_dataloader = model.val_dataloader__infinite
-    model.test_dataloader = model.test_dataloader__infinite
-
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        max_epochs=1,
-        limit_train_batches=limit_train_batches,
-        limit_val_batches=limit_val_batches,
-        limit_test_batches=limit_test_batches,
-    )
-
-    results = trainer.fit(model)
-    assert results
-    assert trainer.num_training_batches == limit_train_batches
-    assert trainer.num_val_batches[0] == limit_val_batches
-
-    trainer.test(ckpt_path=None)
-    assert trainer.num_test_batches[0] == limit_test_batches
-
-
 @pytest.mark.parametrize(
     ['limit_train_batches', 'limit_val_batches', 'limit_test_batches'],
     [
@@ -329,7 +266,7 @@ def test_inf_dataloaders_with_limit_num_batches(tmpdir, limit_train_batches, lim
     ]
 )
 def test_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches):
-    """Verify num_batches for train, val & test dataloaders passed with batch limit in percent"""
+    """Verify num_batches for val & test dataloaders passed with batch limit in percent"""
     model = EvalModelTemplate()
     model.val_dataloader = model.val_dataloader__multiple_mixed_length
     model.test_dataloader = model.test_dataloader__multiple_mixed_length
@@ -370,7 +307,7 @@ def test_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches, lim
     ]
 )
 def test_dataloaders_with_limit_num_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches):
-    """Verify num_batches for train, val & test dataloaders passed with batch limit as number"""
+    """Verify num_batches for val & test dataloaders passed with batch limit as number"""
     os.environ['PL_DEV_DEBUG'] = '1'
 
     model = EvalModelTemplate()
@@ -499,7 +436,7 @@ def test_train_inf_dataloader_error(tmpdir):
 
     trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, val_check_interval=0.5)
 
-    with pytest.raises(MisconfigurationException, match='using an IterableDataset'):
+    with pytest.raises(MisconfigurationException, match='infinite DataLoader'):
         trainer.fit(model)
 
 
@@ -510,7 +447,7 @@ def test_val_inf_dataloader_error(tmpdir):
 
     trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_val_batches=0.5)
 
-    with pytest.raises(MisconfigurationException, match='using an IterableDataset'):
+    with pytest.raises(MisconfigurationException, match='infinite DataLoader'):
         trainer.fit(model)
 
 
@@ -521,7 +458,7 @@ def test_test_inf_dataloader_error(tmpdir):
 
     trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_test_batches=0.5)
 
-    with pytest.raises(MisconfigurationException, match='using an IterableDataset'):
+    with pytest.raises(MisconfigurationException, match='infinite DataLoader'):
         trainer.test(model)
 
 
@@ -837,7 +774,7 @@ def test_train_dataloader_not_implemented_error_failed(tmpdir):
 
     trainer = Trainer(default_root_dir=tmpdir, max_steps=5, max_epochs=1, val_check_interval=0.5)
 
-    with pytest.raises(MisconfigurationException, match='using an IterableDataset'):
+    with pytest.raises(MisconfigurationException, match='infinite DataLoader'):
         trainer.fit(model)
 
 
@@ -848,7 +785,7 @@ def test_val_dataloader_not_implemented_error_failed(tmpdir):
 
     trainer = Trainer(default_root_dir=tmpdir, max_steps=5, max_epochs=1, limit_val_batches=0.5)
 
-    with pytest.raises(MisconfigurationException, match='using an IterableDataset'):
+    with pytest.raises(MisconfigurationException, match='infinite DataLoader'):
         trainer.fit(model)
 
 
@@ -859,5 +796,5 @@ def test_test_dataloader_not_implemented_error_failed(tmpdir):
 
     trainer = Trainer(default_root_dir=tmpdir, max_steps=5, max_epochs=1, limit_test_batches=0.5)
 
-    with pytest.raises(MisconfigurationException, match='using an IterableDataset'):
+    with pytest.raises(MisconfigurationException, match='infinite DataLoader'):
         trainer.test(model)