Remove `__len__` from CombinedStreamingDataset (#19321)

This commit is contained in:
awaelchli 2024-01-24 17:07:32 +01:00 committed by GitHub
parent b446b08be5
commit 71bfdc3c60
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 5 additions and 6 deletions

View File

@ -27,8 +27,11 @@ class CombinedStreamingDataset(IterableDataset):
"""The `CombinedStreamingDataset` enables to stream data from multiple StreamingDataset with the sampling ratio of
your choice.
Addtionally, the `CombinedStreamingDataset` keeps track of the number of
samples fetched to enable resumability of the datasets.
Addtionally, the `CombinedStreamingDataset` keeps track of the number of samples fetched to enable resumability
of the datasets.
Note that due to the random sampling, the number of samples returned from the iterator is variable and a function
of the given seed. The combined dataset will raise a StopIteration as soon as any of the datasets is exhausted.
"""
@ -71,10 +74,6 @@ class CombinedStreamingDataset(IterableDataset):
# Used to prevent returning num_samples_yielded when using PyTorch DataLoader
self._use_streaming_dataloader = use_streaming_dataloader
def __len__(self) -> int:
assert self._weights
return int(min([1 / w * len(d) for w, d in zip(self._weights, self._datasets) if w > 0]))
def __iter__(self) -> Iterator[Any]:
assert self._weights