From fb42872259aeb90deb04c9e2110b992599979634 Mon Sep 17 00:00:00 2001 From: Donal Byrne Date: Sun, 29 Mar 2020 20:29:09 +0100 Subject: [PATCH] Update docs iterable datasets (#1281) * Updated Sequencial Data docs * Sequntial Data section now contains info on using IterableDatasets * * Undid reformatting of bullet points * * added information about val_check_interval Co-authored-by: Donal Byrne --- docs/source/sequences.rst | 34 +++++++++++++++++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/docs/source/sequences.rst b/docs/source/sequences.rst index 3e0e064bcd..5e52dbf4ae 100644 --- a/docs/source/sequences.rst +++ b/docs/source/sequences.rst @@ -42,4 +42,36 @@ Lightning can handle TBTT automatically via this flag. override :meth:`pytorch_lightning.core.LightningModule.tbptt_split_batch`. .. note:: Using this feature requires updating your LightningModule's :meth:`pytorch_lightning.core.LightningModule.training_step` to include - a `hiddens` arg. \ No newline at end of file + a `hiddens` arg. + +Iterable Datasets +--------------------------------------- +Lightning supports using IterableDatasets as well as map-style Datasets. IterableDatasets provide a more natural +option when using sequential data. + +.. note:: When using an IterableDataset you must set the val_check_interval to an int (specifying the number of training + batches to run before validation) when initializing the Trainer even when there is no validation logic in place. + This is due to the fact that the IterableDataset does not have a __len__ and Lightning requires this to calculate + the default validation interval. + +.. 
code-block:: python
+
+    # IterableDataset
+    class CustomDataset(IterableDataset):
+
+        def __init__(self, data):
+            self.data_source = data
+
+        def __iter__(self):
+            return iter(self.data_source)
+
+    # Setup DataLoader
+    def train_dataloader(self):
+        seq_data = ['A', 'long', 'time', 'ago', 'in', 'a', 'galaxy', 'far', 'far', 'away']
+        iterable_dataset = CustomDataset(seq_data)
+
+        dataloader = DataLoader(dataset=iterable_dataset, batch_size=5)
+        return dataloader
+
+    # Set val_check_interval
+    trainer = pl.Trainer(val_check_interval=1000)