[Fault Tolerance] Don't check the len of a dataset, but its instance. (#10432)
This commit is contained in:
parent
edbf27430d
commit
f1882aa69f
|
@ -37,7 +37,7 @@ from pytorch_lightning.utilities.auto_restart import (
|
|||
CaptureMapDataset,
|
||||
FastForwardSampler,
|
||||
)
|
||||
from pytorch_lightning.utilities.data import has_iterable_dataset, has_len_all_ranks
|
||||
from pytorch_lightning.utilities.data import get_len, has_iterable_dataset, has_len_all_ranks
|
||||
from pytorch_lightning.utilities.enums import DistributedType
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.imports import _fault_tolerant_training
|
||||
|
@ -282,10 +282,11 @@ class TrainerDataLoadingMixin(ABC):
|
|||
dl_kwargs["sampler"] = None
|
||||
|
||||
if _fault_tolerant_training():
|
||||
if isinstance(dl_kwargs["dataset"], IterableDataset):
|
||||
dataset = dl_kwargs["dataset"]
|
||||
if isinstance(dataset, IterableDataset):
|
||||
# wrap the `IterableDataset` into a `CaptureIterableDataset` to record sampler states.
|
||||
dl_kwargs["dataset"] = CaptureIterableDataset(dataset=dl_kwargs["dataset"])
|
||||
elif len(dl_kwargs["dataset"]):
|
||||
elif get_len(dataset) != float("inf"):
|
||||
dl_kwargs["dataset"] = CaptureMapDataset(dataset=dl_kwargs["dataset"])
|
||||
else:
|
||||
raise MisconfigurationException(
|
||||
|
|
Loading…
Reference in New Issue