Update docs iterable datasets (#1281)
* Updated Sequential Data docs
* Sequential Data section now contains info on using IterableDatasets
* Undid reformatting of bullet points
* Added information about val_check_interval

Co-authored-by: Donal Byrne <Donal.Byrne@xperi.com>
parent ab09faa15e
commit fb42872259
@@ -42,4 +42,36 @@ Lightning can handle TBTT automatically via this flag.
override :meth:`pytorch_lightning.core.LightningModule.tbptt_split_batch`.

.. note:: Using this feature requires updating your LightningModule's :meth:`pytorch_lightning.core.LightningModule.training_step` to include
    a `hiddens` arg.

Iterable Datasets
---------------------------------------
Lightning supports IterableDatasets as well as map-style Datasets. IterableDatasets are often a more natural
fit for sequential data.

.. note:: When using an IterableDataset you must set ``val_check_interval`` to an int (the number of training
    batches to run before validation) when initializing the Trainer, even when there is no validation logic in place.
    This is because an IterableDataset does not have a ``__len__``, which Lightning requires to calculate
    the default validation interval.

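As a minimal sketch of why the length is unavailable, the plain-Python example below (no Lightning or PyTorch needed) uses a hypothetical ``StreamOnly`` class as a stand-in for an IterableDataset. It is an illustration under that assumption, not Lightning's own check: an object that defines only ``__iter__`` cannot answer ``len()``, so the number of batches per epoch is not known up front.

```python
class StreamOnly:
    """Hypothetical stand-in for an IterableDataset: has __iter__ but no __len__."""

    def __init__(self, data):
        self.data_source = data

    def __iter__(self):
        return iter(self.data_source)


def has_length(obj):
    """Return True if len(obj) works, False if obj does not define __len__."""
    try:
        len(obj)
    except TypeError:
        return False
    return True


stream = StreamOnly(['A', 'long', 'time', 'ago'])
print(has_length(stream))        # False: only iteration is supported
print(has_length(list(stream)))  # True: a materialized list knows its size
```

Materializing the stream into a list restores ``len()``, but that defeats the purpose of streaming when the data is large or unbounded.
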
.. code-block:: python

    import pytorch_lightning as pl
    from torch.utils.data import DataLoader, IterableDataset

    # IterableDataset
    class CustomDataset(IterableDataset):

        def __init__(self, data):
            # Keep a reference to the underlying sequence so __iter__ can stream from it
            self.data_source = data

        def __iter__(self):
            return iter(self.data_source)

    # Setup DataLoader (defined inside your LightningModule)
    def train_dataloader(self):
        seq_data = ['A', 'long', 'time', 'ago', 'in', 'a', 'galaxy', 'far', 'far', 'away']
        iterable_dataset = CustomDataset(seq_data)

        dataloader = DataLoader(dataset=iterable_dataset, batch_size=5)
        return dataloader

    # Set val_check_interval to an int: run validation every 1000 training batches
    trainer = pl.Trainer(val_check_interval=1000)
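
To make the effect of ``batch_size=5`` on a stream concrete, here is a rough plain-Python sketch. The ``batches`` helper is hypothetical and is not how ``DataLoader`` is implemented (the real loader also collates samples and can use worker processes); it only shows items being pulled from an iterator in fixed-size groups until the stream runs out.

```python
from itertools import islice


def batches(iterable, batch_size):
    """Yield lists of up to batch_size items until the iterable is exhausted."""
    iterator = iter(iterable)
    while True:
        batch = list(islice(iterator, batch_size))
        if not batch:
            return
        yield batch


seq_data = ['A', 'long', 'time', 'ago', 'in', 'a', 'galaxy', 'far', 'far', 'away']
for batch in batches(seq_data, batch_size=5):
    print(batch)
```

With the ten-word sequence above this produces two batches of five. The total number of batches only becomes known once the stream is exhausted, which is why the validation interval has to be given explicitly as a batch count.
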