diff --git a/src/lightning/data/streaming/combined.py b/src/lightning/data/streaming/combined.py index c4453cf720..7b3373a3e8 100644 --- a/src/lightning/data/streaming/combined.py +++ b/src/lightning/data/streaming/combined.py @@ -27,8 +27,11 @@ class CombinedStreamingDataset(IterableDataset): """The `CombinedStreamingDataset` enables to stream data from multiple StreamingDataset with the sampling ratio of your choice. - Addtionally, the `CombinedStreamingDataset` keeps track of the number of - samples fetched to enable resumability of the datasets. + Additionally, the `CombinedStreamingDataset` keeps track of the number of samples fetched to enable resumability + of the datasets. + + Note that due to the random sampling, the number of samples returned from the iterator is variable and a function + of the given seed. The combined dataset will raise a StopIteration as soon as any of the datasets is exhausted. """ @@ -71,10 +74,6 @@ class CombinedStreamingDataset(IterableDataset): # Used to prevent returning num_samples_yielded when using PyTorch DataLoader self._use_streaming_dataloader = use_streaming_dataloader - def __len__(self) -> int: - assert self._weights - return int(min([1 / w * len(d) for w, d in zip(self._weights, self._datasets) if w > 0])) - def __iter__(self) -> Iterator[Any]: assert self._weights