diff --git a/CHANGELOG.md b/CHANGELOG.md index 664477da03..0c85b5f79d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -41,10 +41,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fault-tolerant training: * Added `FastForwardSampler` and `CaptureIterableDataset` injection to data loading utilities ([#8366](https://github.com/PyTorchLightning/pytorch-lightning/pull/8366)) - * Added `LightningDataFetcher` to control fetching flow ([#8890](https://github.com/PyTorchLightning/pytorch-lightning/pull/8890)) + * Added `DataFetcher` to control fetching flow ([#8890](https://github.com/PyTorchLightning/pytorch-lightning/pull/8890)) * Added `SharedCycleIteratorState` to prevent infinite loop ([#8889](https://github.com/PyTorchLightning/pytorch-lightning/pull/8889)) * Added `CaptureMapDataset` for state management in map-style datasets ([#8891](https://github.com/PyTorchLightning/pytorch-lightning/pull/8891)) - * Added Fault Tolerant Training to LightningFetcher ([#8891](https://github.com/PyTorchLightning/pytorch-lightning/pull/8891)) + * Added Fault Tolerant Training to `DataFetcher` ([#8891](https://github.com/PyTorchLightning/pytorch-lightning/pull/8891)) + * Replaced old prefetch iterator with new `DataFetcher` in training loop ([#8953](https://github.com/PyTorchLightning/pytorch-lightning/pull/8953)) - Added `CheckpointIO` to expose checkpoint IO from training type plugin ([#8743](https://github.com/PyTorchLightning/pytorch-lightning/pull/8743)) diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py index dc91a2f29d..d310d42f93 100644 --- a/pytorch_lightning/loops/dataloader/evaluation_loop.py +++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py @@ -20,6 +20,7 @@ from torch.utils.data.dataloader import DataLoader from pytorch_lightning.loops.dataloader import DataLoaderLoop from pytorch_lightning.loops.epoch import EvaluationEpochLoop from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection +from pytorch_lightning.utilities.fetching import DataFetcher from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.types import EPOCH_OUTPUT @@ -98,7 +99,9 @@ class EvaluationLoop(DataLoaderLoop): """Performs evaluation on one single dataloader""" void(*args, **kwargs) dataloader = self.trainer.accelerator.process_dataloader(self.current_dataloader) - dataloader_iter = enumerate(dataloader) + data_fetcher = DataFetcher() + data_fetcher.setup(dataloader) + dataloader_iter = enumerate(data_fetcher) dl_max_batches = self._max_batches[self.current_dataloader_idx] dl_outputs = self.epoch_loop.run( diff --git a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py index ec3e63f933..b03b15d820 100644 --- a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py @@ -86,7 +86,7 @@ class EvaluationEpochLoop(Loop): """ void(dl_max_batches, num_dataloaders) - batch_idx, batch = next(dataloader_iter) + batch_idx, (batch, _) = next(dataloader_iter) if batch is None: raise StopIteration diff --git a/pytorch_lightning/profiler/base.py b/pytorch_lightning/profiler/base.py index 2ca10d73f8..af885efe28 100644 --- a/pytorch_lightning/profiler/base.py +++ b/pytorch_lightning/profiler/base.py @@ -17,7 +17,7 @@ import os from abc import ABC, abstractmethod from contextlib import contextmanager from pathlib import Path -from typing import Any, Callable, Dict, Optional, TextIO, Union +from typing import Any, Callable, Dict, Generator, Iterable, Optional, TextIO, Union from pytorch_lightning.utilities import rank_zero_deprecation from pytorch_lightning.utilities.cloud_io import get_filesystem @@ -78,7 +78,7 @@ class BaseProfiler(AbstractProfiler): self._stage: Optional[str] = None @contextmanager - def profile(self, action_name: str) -> None: + def profile(self, action_name: str) -> Generator: """ Yields a context manager to encapsulate the scope of a profiled action. @@ -96,7 +96,7 @@ class BaseProfiler(AbstractProfiler): finally: self.stop(action_name) - def profile_iterable(self, iterable, action_name: str) -> None: + def profile_iterable(self, iterable: Iterable, action_name: str) -> Generator: iterator = iter(iterable) while True: try: diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 74b462e19a..5b967fed81 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -24,7 +24,7 @@ import pytorch_lightning as pl from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, rank_zero_deprecation, rank_zero_info, rank_zero_warn from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.imports import _fault_tolerant_enabled +from pytorch_lightning.utilities.imports import _fault_tolerant_training from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS if _OMEGACONF_AVAILABLE: @@ -348,7 +348,7 @@ class CheckpointConnector: "pytorch-lightning_version": pl.__version__, "state_dict": self._get_lightning_module_state_dict(), } - if _fault_tolerant_enabled(): + if _fault_tolerant_training(): checkpoint["loops"] = self._get_loops_state_dict() if not weights_only: @@ -451,7 +451,7 @@ class CheckpointConnector: def _get_lightning_module_state_dict(self) -> Dict[str, torch.Tensor]: metrics = ( [m for m in self.trainer.lightning_module.modules() if isinstance(m, Metric)] - if _fault_tolerant_enabled() + if _fault_tolerant_training() else [] ) diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index b93d24b7a4..629be97f29 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -12,12 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Optional, Union +from typing import Callable, Iterable, Optional, Union import pytorch_lightning as pl -from pytorch_lightning.trainer.supporters import prefetch_iterator from pytorch_lightning.utilities import rank_zero_deprecation from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.fetching import DataFetcher from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS @@ -26,6 +26,7 @@ class DataConnector: def __init__(self, trainer: "pl.Trainer", multiple_trainloader_mode: str = "max_size_cycle"): self.trainer = trainer self.multiple_trainloader_mode = multiple_trainloader_mode + self.data_fetcher: Optional[DataFetcher] = None def on_trainer_init( self, @@ -59,10 +60,11 @@ class DataConnector: self.trainer.reload_dataloaders_every_n_epochs = reload_dataloaders_every_n_epochs self.trainer._is_data_prepared = False - def get_profiled_train_dataloader(self, train_dataloader): - profiled_dl = self.trainer.profiler.profile_iterable( - enumerate(prefetch_iterator(train_dataloader)), "get_train_batch" - ) + def get_profiled_train_dataloader(self, train_dataloader) -> Iterable: + self.data_fetcher = DataFetcher() + self.data_fetcher.setup(train_dataloader) + prefetcher_iter = iter(self.data_fetcher) + profiled_dl = self.trainer.profiler.profile_iterable(enumerate(prefetcher_iter), "get_train_batch") return profiled_dl def prepare_data(self) -> None: diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index e63138a74d..2ea2a74c7a 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -34,12 +34,13 @@ from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.auto_restart import ( _capture_metadata_collate, CaptureIterableDataset, + CaptureMapDataset, FastForwardSampler, ) from pytorch_lightning.utilities.data import has_iterable_dataset, has_len from pytorch_lightning.utilities.debugging import InternalDebugger from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.imports import _fault_tolerant_enabled +from pytorch_lightning.utilities.imports import _fault_tolerant_training from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.seed import pl_worker_init_function @@ -168,7 +169,7 @@ class TrainerDataLoadingMixin(ABC): if is_predicting: batch_sampler = IndexBatchSamplerWrapper(batch_sampler) - if _fault_tolerant_enabled(): + if _fault_tolerant_training(): fast_forward_sampler = batch_sampler = FastForwardSampler(batch_sampler) fast_forward_sampler.setup(dataloader_batch_size=1) @@ -180,7 +181,7 @@ class TrainerDataLoadingMixin(ABC): "drop_last": False, } - if _fault_tolerant_enabled(): + if _fault_tolerant_training(): fast_forward_sampler = sampler = FastForwardSampler(sampler) fast_forward_sampler.setup(dataloader_batch_size=dataloader.batch_size) @@ -246,14 +247,20 @@ class TrainerDataLoadingMixin(ABC): f"`{dataloader_cls_name}(dataset, sampler=DistributedSampler(dataset))`." ) - # wrap the `IterableDataset` into a `CaptureIterableDataset` to record sampler states. - if _fault_tolerant_enabled() and isinstance(dl_kwargs["dataset"], IterableDataset): - dl_kwargs["dataset"] = CaptureIterableDataset(dataset=dl_kwargs["dataset"]) + if isinstance(dl_kwargs["dataset"], IterableDataset): + dl_kwargs["batch_sampler"] = None dl_kwargs["sampler"] = None - if isinstance(dl_kwargs["dataset"], IterableDataset): - del dl_kwargs["sampler"] - del dl_kwargs["batch_sampler"] + if _fault_tolerant_training(): + if isinstance(dl_kwargs["dataset"], IterableDataset): + # wrap the `IterableDataset` into a `CaptureIterableDataset` to record sampler states. + dl_kwargs["dataset"] = CaptureIterableDataset(dataset=dl_kwargs["dataset"]) + elif len(dl_kwargs["dataset"]): + dl_kwargs["dataset"] = CaptureMapDataset(dataset=dl_kwargs["dataset"]) + else: + raise MisconfigurationException( + "This shouldn't happen, please open an issue on Lightning Github repository." + ) return dl_kwargs @@ -308,7 +315,7 @@ class TrainerDataLoadingMixin(ABC): apply_to_collection(self.train_dataloader, DataLoader, self.auto_add_worker_init_fn) # add collate_fn to collect metadata for fault tolerant training - if _fault_tolerant_enabled(): + if _fault_tolerant_training(): apply_to_collection(self.train_dataloader, DataLoader, self._add_sampler_metadata_collate) # wrap the sequence of train loaders to a CombinedLoader object for computing the num_training_batches diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index 9eaa2d28a4..21f64e0780 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -15,7 +15,7 @@ from collections.abc import Iterable, Iterator, Mapping, Sequence from dataclasses import dataclass, field from functools import partial -from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Union import torch from torch.utils.data import Dataset @@ -30,7 +30,7 @@ from pytorch_lightning.utilities.auto_restart import ( ) from pytorch_lightning.utilities.data import get_len from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.imports import _fault_tolerant_enabled +from pytorch_lightning.utilities.imports import _fault_tolerant_training class TensorRunningAccum: @@ -375,7 +375,7 @@ class CombinedLoader: num_batches_processed: The number of batches processed so far, needed because the individual dataloaders may have already prefetched more batches by the time a state dict is requested. """ - if not _fault_tolerant_enabled(): + if not _fault_tolerant_training(): return DataLoaderDict() state_dict_fn = partial(self._state_dict_fn, num_batches_processed=num_batches_processed) @@ -541,7 +541,7 @@ class CombinedLoaderIterator: def next_fn(iterator: Iterator): batch = next(iterator) - if not _fault_tolerant_enabled(): + if not _fault_tolerant_training(): return batch # when fault tolerant is enabled, the iterator will return # `FastForwardSampler` state_dict metadata @@ -592,25 +592,3 @@ def _nested_calc_num_data(data: Union[Mapping, Sequence], compute_func: Callable new_data.append(x) return compute_func(new_data) - - -def prefetch_iterator(iterable: Iterable) -> Generator[Tuple[Any, bool], None, None]: - """ - Returns an iterator that pre-fetches and caches the next item. - The values are passed through from the given iterable with an added boolean indicating if this is the last item. - See `https://stackoverflow.com/a/1630350 `_ - """ - it = iter(iterable) - - try: - # the iterator may be empty from the beginning - last = next(it) - except StopIteration: - return - - for val in it: - # yield last and has next - yield last, False - last = val - # yield last, no longer has next - yield last, True diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index cc3b991053..8bc7370860 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -77,7 +77,7 @@ from pytorch_lightning.utilities import ( from pytorch_lightning.utilities.debugging import InternalDebugger from pytorch_lightning.utilities.distributed import distributed_available from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.imports import _fault_tolerant_enabled +from pytorch_lightning.utilities.imports import _fault_tolerant_training from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.model_summary import ModelSummary, summarize from pytorch_lightning.utilities.seed import reset_seed @@ -1344,7 +1344,7 @@ class Trainer( ) def _on_exception(self): - if not _fault_tolerant_enabled(): + if not _fault_tolerant_training(): return # save a checkpoint for fault tolerant training. we don't use `log_dir` to minimize the chances of failure. file_path = os.path.join(self.default_root_dir, ".pl_auto_save.ckpt") diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index c60c42bdc9..c9e378dbad 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -21,6 +21,7 @@ from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Tup from torch.utils.data import Dataset, get_worker_info, Sampler from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter, DataLoader, IterableDataset +import pytorch_lightning as pl from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.enums import AutoRestartBatchKeys from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -515,7 +516,10 @@ def _capture_metadata_collate(samples: List, dataset: Dataset, default_collate: def patch_dataloader_iterator( - dataloader: DataLoader, iterator: Iterator, prefetcher, num_batches_fetched: int = 0 + dataloader: DataLoader, + iterator: Iterator, + data_fecher: "pl.utilities.fetching.DataFetcher", + num_batches_fetched: int = 0, ) -> None: assert isinstance(dataloader.dataset, (CaptureMapDataset, CaptureIterableDataset)) @@ -554,7 +558,7 @@ def patch_dataloader_iterator( num_batches_fetched=num_batches_fetched, ) ] - prefetcher._store_dataloader_iter_state(it, state) + data_fecher._store_dataloader_iter_state(it, state) return batch return wrapper diff --git a/pytorch_lightning/utilities/fetching.py b/pytorch_lightning/utilities/fetching.py index f053f13297..42de4f8571 100644 --- a/pytorch_lightning/utilities/fetching.py +++ b/pytorch_lightning/utilities/fetching.py @@ -29,10 +29,10 @@ from pytorch_lightning.utilities.auto_restart import ( patch_dataloader_iterator, ) from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.imports import _fault_tolerant_enabled +from pytorch_lightning.utilities.imports import _fault_tolerant_training -class AbstractFetcher(ABC): +class AbstractDataFetcher(ABC): """ This class is used to control batch fetching flow. @@ -61,13 +61,22 @@ class AbstractFetcher(ABC): self.reset() def setup(self, dataloader: DataLoader, **kwargs) -> None: - if not isinstance(dataloader, (DataLoader, CombinedLoader)): - raise MisconfigurationException( - "The `DataFetcher` should be setup with an instance of a PyTorch ``DataLoader``." - ) + self._add_capture_metadata_collate(dataloader) self.dataloader = dataloader - if isinstance(dataloader, DataLoader) and not isinstance(dataloader.collate_fn, partial): - _add_capture_metadata_collate(dataloader) + + @staticmethod + def _add_capture_metadata_collate(dataloader: Iterable) -> None: + if not isinstance(dataloader, (DataLoader, CombinedLoader)): + return + + if isinstance(dataloader, CombinedLoader): + dataloader = dataloader.loaders + + def add_capture_metadata_collate(dataloader: DataLoader): + if not isinstance(dataloader.collate_fn, partial): + _add_capture_metadata_collate(dataloader) + + apply_to_collection(dataloader, DataLoader, add_capture_metadata_collate) def add_batch(self, batch) -> None: self.batches.append(batch) @@ -82,7 +91,7 @@ class AbstractFetcher(ABC): # cycle_iterator = iterator iterator = iterator._loader_iter - if isinstance(loader, DataLoader) and _fault_tolerant_enabled(): + if isinstance(loader, DataLoader) and _fault_tolerant_training(): loader._lightning_fetcher = self patch_dataloader_iterator(loader, iterator, self) @@ -161,7 +170,7 @@ class AbstractFetcher(ABC): self.done: bool = False -class LightningDataFetcher(AbstractFetcher): +class DataFetcher(AbstractDataFetcher): """ This class is used to control batch fetching flow. diff --git a/pytorch_lightning/utilities/imports.py b/pytorch_lightning/utilities/imports.py index 5e8cdd2137..f999847160 100644 --- a/pytorch_lightning/utilities/imports.py +++ b/pytorch_lightning/utilities/imports.py @@ -103,9 +103,6 @@ else: _IPU_AVAILABLE = False -def _fault_tolerant_enabled() -> bool: - """ - EXPERIMENTAL - the `reset` function from `_MultiProcessingDataLoaderIter` was introduced in PyTorch 1.7 but we need to mock it. - """ +# experimental feature within PyTorch Lightning. +def _fault_tolerant_training() -> bool: return _TORCH_GREATER_EQUAL_1_7 and int(os.getenv("PL_FAULT_TOLERANT_TRAINING", 0)) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 0906ed3820..314ed899c5 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -1111,6 +1111,7 @@ def test_current_score_when_nan(tmpdir, mode: str): model_checkpoint = ModelCheckpoint(dirpath=tmpdir, save_top_k=1, monitor="foo", mode=mode) trainer = Trainer( default_root_dir=tmpdir, + max_epochs=1, limit_train_batches=1, limit_val_batches=1, callbacks=[model_checkpoint], @@ -1133,6 +1134,7 @@ def test_hparams_type(tmpdir, hparams_type): model_checkpoint = ModelCheckpoint(dirpath=tmpdir, save_top_k=1, monitor="foo") trainer = Trainer( + max_epochs=1, default_root_dir=tmpdir, limit_train_batches=1, limit_val_batches=1, diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index 0c90dee2e5..7c6a985d09 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -28,7 +28,7 @@ import tests.helpers.utils as tutils from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.trainer.connectors.logger_connector.result import _Sync, MetricSource, ResultCollection -from pytorch_lightning.utilities.imports import _fault_tolerant_enabled, _TORCH_GREATER_EQUAL_1_7 +from pytorch_lightning.utilities.imports import _fault_tolerant_training, _TORCH_GREATER_EQUAL_1_7 from tests.helpers import BoringModel from tests.helpers.runif import RunIf @@ -384,7 +384,7 @@ def result_collection_reload(**kwargs): and final accumulation with Fault Tolerant Training is correct. """ - if not _fault_tolerant_enabled(): + if not _fault_tolerant_training(): pytest.skip("Fault tolerant not available") num_processes = kwargs.get("gpus", 1) diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py index 45c6453688..65cbebc820 100644 --- a/tests/loops/test_loops.py +++ b/tests/loops/test_loops.py @@ -379,6 +379,7 @@ def test_loop_state_on_exception(accumulate_grad_batches, stop_epoch, stop_batch pass ckpt_path = str(tmpdir / ".pl_auto_save.ckpt") + assert os.path.exists(ckpt_path) checkpoint = torch.load(ckpt_path) optim_progress = trainer.fit_loop.epoch_loop.batch_loop.optim_progress diff --git a/tests/trainer/test_supporters.py b/tests/trainer/test_supporters.py index e8e5d0be10..4375bf7f25 100644 --- a/tests/trainer/test_supporters.py +++ b/tests/trainer/test_supporters.py @@ -29,7 +29,6 @@ from pytorch_lightning.trainer.supporters import ( CombinedLoader, CombinedLoaderIterator, CycleIterator, - prefetch_iterator, TensorRunningAccum, ) from pytorch_lightning.utilities.apply_func import apply_to_collection @@ -80,28 +79,6 @@ def test_none_length_cycle_iterator(): assert item == 0 -def test_prefetch_iterator(): - """Test the prefetch_iterator with PyTorch IterableDataset.""" - - class IterDataset(IterableDataset): - def __iter__(self): - yield 1 - yield 2 - yield 3 - - dataset = IterDataset() - iterator = prefetch_iterator(dataset) - assert list(iterator) == [(1, False), (2, False), (3, True)] - - class EmptyIterDataset(IterableDataset): - def __iter__(self): - return iter([]) - - dataset = EmptyIterDataset() - iterator = prefetch_iterator(dataset) - assert list(iterator) == [] - - @pytest.mark.parametrize( ["dataset_1", "dataset_2"], [ diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index 361600b9cd..e665fc79e4 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -39,7 +39,7 @@ from pytorch_lightning.utilities.auto_restart import ( ) from pytorch_lightning.utilities.enums import AutoRestartBatchKeys from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.imports import _fault_tolerant_enabled +from pytorch_lightning.utilities.imports import _fault_tolerant_training from tests.helpers.boring_model import BoringModel from tests.helpers.runif import RunIf @@ -100,13 +100,6 @@ def _generate_state(base_seed, worker_id): return state -@RunIf(min_torch="1.7.0") -@pytest.mark.parametrize("env_setting,expected", [("0", False), ("1", True)]) -def test_fault_tolerant_enabled(env_setting, expected): - with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": env_setting}): - assert _fault_tolerant_enabled() == expected - - def test_fast_forward_getattr(): dataset = range(15) sampler = SequentialSampler(dataset) @@ -647,10 +640,9 @@ def test_fast_forward_sampler_with_distributed_sampler_and_iterative_dataset(): @mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) -@RunIf(max_torch="1.6") +@RunIf(max_torch="1.7") def test_fault_tolerant_not_supported(): - with pytest.raises(MisconfigurationException, match="Restart is only supported with torch >= 1.7.0."): - _fault_tolerant_enabled() + assert not _fault_tolerant_training() def create_iterable_dataset(batch_size, num_workers, attr_name="iter_sampler", wrap: bool = True): diff --git a/tests/utilities/test_fetching.py b/tests/utilities/test_fetching.py index 752fafb27d..35b309549f 100644 --- a/tests/utilities/test_fetching.py +++ b/tests/utilities/test_fetching.py @@ -17,12 +17,12 @@ from torch.utils.data import DataLoader, IterableDataset from pytorch_lightning.trainer.supporters import CombinedLoader from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.fetching import LightningDataFetcher +from pytorch_lightning.utilities.fetching import DataFetcher @pytest.mark.parametrize("use_combined_loader", [False, True]) def test_prefetch_iterator(use_combined_loader): - """Test the LightningDataFetcher with PyTorch IterableDataset.""" + """Test the DataFetcher with PyTorch IterableDataset.""" class IterDataset(IterableDataset): def __iter__(self): @@ -41,7 +41,7 @@ def test_prefetch_iterator(use_combined_loader): else: loader = DataLoader(IterDataset()) expected = [(1, False), (2, False), (3, True)] - iterator = LightningDataFetcher(prefetch_batches=prefetch_batches) + iterator = DataFetcher(prefetch_batches=prefetch_batches) prefetch_batches += 1 assert iterator.prefetch_batches == prefetch_batches iterator.setup(loader) @@ -66,21 +66,14 @@ def test_prefetch_iterator(use_combined_loader): return iter([]) dataloader = DataLoader(EmptyIterDataset()) - iterator = LightningDataFetcher() + iterator = DataFetcher() iterator.setup(dataloader) assert list(iterator) == [] def test_misconfiguration_error(): - fetcher = LightningDataFetcher() - with pytest.raises( - MisconfigurationException, - match="The `DataFetcher` should be setup with an instance of a PyTorch ``DataLoader``.", - ): - fetcher.setup(range(10)) - - fetcher = LightningDataFetcher() + fetcher = DataFetcher() with pytest.raises( MisconfigurationException, match="The `dataloader_iter` isn't available outside the __iter__ context." ):