Update docs iterable datasets (#1281)

* Updated Sequential Data docs

* Sequential Data section now contains info on using IterableDatasets

* Undid reformatting of bullet points

* Added information about val_check_interval

Co-authored-by: Donal Byrne <Donal.Byrne@xperi.com>

@@ -42,4 +42,36 @@ Lightning can handle TBTT automatically via this flag.
override :meth:`pytorch_lightning.core.LightningModule.tbptt_split_batch`.
.. note:: Using this feature requires updating your LightningModule's :meth:`pytorch_lightning.core.LightningModule.training_step` to include
a `hiddens` arg.
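
A minimal sketch of such a ``training_step`` (the LSTM layer, loss function, and batch structure below are
illustrative assumptions, not part of these docs):

.. code-block:: python

    def training_step(self, batch, batch_idx, hiddens):
        # `hiddens` carries the hidden state from the previous truncated
        # backprop split of this batch
        x, y = batch
        out, hiddens = self.lstm(x, hiddens)  # assumes self.lstm = torch.nn.LSTM(...)
        loss = torch.nn.functional.mse_loss(out, y)

        # return the new hidden state so Lightning can feed it to the next split
        return {'loss': loss, 'hiddens': hiddens}
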
Iterable Datasets
---------------------------------------

Lightning supports using ``IterableDataset`` as well as map-style Datasets. An ``IterableDataset`` provides a more
natural option when working with sequential data.

.. note:: When using an ``IterableDataset`` you must set ``val_check_interval`` to an int when initializing the
    ``Trainer``, even when there is no validation logic in place. The int specifies the number of training batches
    to run between validation checks. This is required because an ``IterableDataset`` has no ``__len__``, which
    Lightning would otherwise use to calculate the default validation interval.

.. code-block:: python

    from torch.utils.data import DataLoader, IterableDataset

    import pytorch_lightning as pl

    # IterableDataset
    class CustomDataset(IterableDataset):

        def __init__(self, data):
            # store the underlying data stream
            self.data_source = data

        def __iter__(self):
            return iter(self.data_source)

    # Setup DataLoader
    def train_dataloader(self):
        seq_data = ['A', 'long', 'time', 'ago', 'in', 'a', 'galaxy', 'far', 'far', 'away']
        iterable_dataset = CustomDataset(seq_data)

        dataloader = DataLoader(dataset=iterable_dataset, batch_size=5)
        return dataloader

    # Set val_check_interval
    trainer = pl.Trainer(val_check_interval=1000)
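
Putting the pieces together, here is a minimal end-to-end sketch. The ``LitModel`` class, its single linear layer,
and the random tensor data are illustrative assumptions, not part of the docs above:

.. code-block:: python

    import torch
    from torch.utils.data import DataLoader, IterableDataset

    import pytorch_lightning as pl


    class SequenceDataset(IterableDataset):
        """Streams samples one at a time; deliberately has no __len__."""

        def __init__(self, data):
            self.data_source = data

        def __iter__(self):
            return iter(self.data_source)


    class LitModel(pl.LightningModule):

        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(1, 1)

        def training_step(self, batch, batch_idx):
            loss = self.layer(batch).sum()
            return {'loss': loss}

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=1e-3)

        def train_dataloader(self):
            # 1000 dummy scalar samples streamed through the IterableDataset
            return DataLoader(SequenceDataset(torch.randn(1000, 1)), batch_size=5)


    # an int val_check_interval is required because the dataset has no __len__
    trainer = pl.Trainer(val_check_interval=100, max_epochs=1)
    trainer.fit(LitModel())

Since the dataloader has no length, the epoch ends only when the underlying iterator is exhausted, which is
standard PyTorch ``DataLoader`` behaviour for iterable-style datasets.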