From 71bfdc3c6019e5bab60143fd89c3e797e35cbf46 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 24 Jan 2024 17:07:32 +0100 Subject: [PATCH] Remove `__len__` from CombinedStreamingDataset (#19321) --- src/lightning/data/streaming/combined.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/lightning/data/streaming/combined.py b/src/lightning/data/streaming/combined.py index c4453cf720..7b3373a3e8 100644 --- a/src/lightning/data/streaming/combined.py +++ b/src/lightning/data/streaming/combined.py @@ -27,8 +27,11 @@ class CombinedStreamingDataset(IterableDataset): """The `CombinedStreamingDataset` enables to stream data from multiple StreamingDataset with the sampling ratio of your choice. - Addtionally, the `CombinedStreamingDataset` keeps track of the number of - samples fetched to enable resumability of the datasets. + Addtionally, the `CombinedStreamingDataset` keeps track of the number of samples fetched to enable resumability + of the datasets. + + Note that due to the random sampling, the number of samples returned from the iterator is variable and a function + of the given seed. The combined dataset will raise a StopIteration as soon as any of the datasets is exhausted. """ @@ -71,10 +74,6 @@ class CombinedStreamingDataset(IterableDataset): # Used to prevent returning num_samples_yielded when using PyTorch DataLoader self._use_streaming_dataloader = use_streaming_dataloader - def __len__(self) -> int: - assert self._weights - return int(min([1 / w * len(d) for w, d in zip(self._weights, self._datasets) if w > 0])) - def __iter__(self) -> Iterator[Any]: assert self._weights