[Fault Tolerance] Don't check the len of a dataset, but its instance. (#10432)

thomas chaton 2021-11-09 14:51:53 +00:00 committed by GitHub
parent edbf27430d
commit f1882aa69f
1 changed file with 4 additions and 3 deletions


@@ -37,7 +37,7 @@ from pytorch_lightning.utilities.auto_restart import (
     CaptureMapDataset,
     FastForwardSampler,
 )
-from pytorch_lightning.utilities.data import has_iterable_dataset, has_len_all_ranks
+from pytorch_lightning.utilities.data import get_len, has_iterable_dataset, has_len_all_ranks
 from pytorch_lightning.utilities.enums import DistributedType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _fault_tolerant_training
@@ -282,10 +282,11 @@ class TrainerDataLoadingMixin(ABC):
             dl_kwargs["sampler"] = None
 
         if _fault_tolerant_training():
-            if isinstance(dl_kwargs["dataset"], IterableDataset):
+            dataset = dl_kwargs["dataset"]
+            if isinstance(dataset, IterableDataset):
                 # wrap the `IterableDataset` into a `CaptureIterableDataset` to record sampler states.
                 dl_kwargs["dataset"] = CaptureIterableDataset(dataset=dl_kwargs["dataset"])
-            elif len(dl_kwargs["dataset"]):
+            elif get_len(dataset) != float("inf"):
                 dl_kwargs["dataset"] = CaptureMapDataset(dataset=dl_kwargs["dataset"])
             else:
                 raise MisconfigurationException(
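
Why this change: the old branch `elif len(dl_kwargs["dataset"]):` conflated two questions. It raised a bare `TypeError` for map-style datasets that do not implement `__len__`, and an empty dataset (`len(...) == 0`, which is falsy) wrongly fell through to the `else` branch that raises `MisconfigurationException`. The new check `get_len(dataset) != float("inf")` wraps any dataset with a defined length, even a zero-length one, in `CaptureMapDataset`, and reserves the error for datasets whose length is genuinely unknown. Below is a minimal sketch of the semantics this diff relies on; `get_len_sketch` and the two example datasets are hypothetical stand-ins (the real `pytorch_lightning.utilities.data.get_len` goes through a `has_len` check, but behaves equivalently for these cases):

    from typing import Union

    from torch.utils.data import Dataset, IterableDataset


    def get_len_sketch(dataset) -> Union[int, float]:
        # Simplified stand-in for get_len: return len(dataset) when
        # __len__ is implemented, float("inf") when the length is unknown.
        try:
            return len(dataset)
        except (TypeError, NotImplementedError):
            return float("inf")


    class EmptyMapDataset(Dataset):
        # Hypothetical example: an empty but sized map-style dataset.
        def __len__(self) -> int:
            return 0

        def __getitem__(self, idx):
            raise IndexError


    class StreamDataset(IterableDataset):
        # Hypothetical example: no __len__, so its length is unknown.
        def __iter__(self):
            yield from range(3)


    # Old check: len(EmptyMapDataset()) is 0 (falsy), so the dataset wrongly
    # hit the MisconfigurationException branch. New check: 0 != inf, so it
    # is wrapped in CaptureMapDataset as intended.
    assert get_len_sketch(EmptyMapDataset()) == 0             # finite -> CaptureMapDataset
    assert get_len_sketch(StreamDataset()) == float("inf")    # unknown -> error branch
    # len(StreamDataset()) would raise TypeError under the old check.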