3/n integrate new LightningDataFetcher into loop (#8953)
Co-authored-by: tchaton <thomas@grid.ai>
Co-authored-by: Kaushik B <45285388+kaushikb11@users.noreply.github.com>
parent 77bc5d4004
commit 522df2b89b
@@ -41,10 +41,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fault-tolerant training:
     * Added `FastForwardSampler` and `CaptureIterableDataset` injection to data loading utilities ([#8366](https://github.com/PyTorchLightning/pytorch-lightning/pull/8366))
-    * Added `LightningDataFetcher` to control fetching flow ([#8890](https://github.com/PyTorchLightning/pytorch-lightning/pull/8890))
+    * Added `DataFetcher` to control fetching flow ([#8890](https://github.com/PyTorchLightning/pytorch-lightning/pull/8890))
     * Added `SharedCycleIteratorState` to prevent infinite loop ([#8889](https://github.com/PyTorchLightning/pytorch-lightning/pull/8889))
     * Added `CaptureMapDataset` for state management in map-style datasets ([#8891](https://github.com/PyTorchLightning/pytorch-lightning/pull/8891))
-    * Added Fault Tolerant Training to LightningFetcher ([#8891](https://github.com/PyTorchLightning/pytorch-lightning/pull/8891))
+    * Added Fault Tolerant Training to `DataFetcher` ([#8891](https://github.com/PyTorchLightning/pytorch-lightning/pull/8891))
+    * Replaced old prefetch iterator with new `DataFetcher` in training loop ([#8953](https://github.com/PyTorchLightning/pytorch-lightning/pull/8953))

 - Added `CheckpointIO` to expose checkpoint IO from training type plugin ([#8743](https://github.com/PyTorchLightning/pytorch-lightning/pull/8743))
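For orientation, a minimal sketch of the fetching contract that the loop changes below rely on: `DataFetcher.setup()` is called on a dataloader, and iterating the fetcher yields `(batch, is_last)` pairs. This is pieced together from the usage visible in this diff and its tests; the toy dataset and variable names are illustrative only.

from torch.utils.data import DataLoader, IterableDataset

from pytorch_lightning.utilities.fetching import DataFetcher


class ToyIterDataset(IterableDataset):
    def __iter__(self):
        yield from (1, 2, 3)


loader = DataLoader(ToyIterDataset())
fetcher = DataFetcher()  # a `prefetch_batches` argument is also accepted (see the test hunk near the end)
fetcher.setup(loader)    # attaches the fetcher to the dataloader

# Iterating the fetcher yields `(batch, is_last)` tuples, so consumers can tell
# when the final batch has arrived without an extra StopIteration round-trip.
for batch_idx, (batch, is_last) in enumerate(fetcher):
    pass  # `is_last` is False for every batch except the final one
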
@@ -20,6 +20,7 @@ from torch.utils.data.dataloader import DataLoader
 from pytorch_lightning.loops.dataloader import DataLoaderLoop
 from pytorch_lightning.loops.epoch import EvaluationEpochLoop
 from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
+from pytorch_lightning.utilities.fetching import DataFetcher
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.types import EPOCH_OUTPUT

@@ -98,7 +99,9 @@ class EvaluationLoop(DataLoaderLoop):
         """Performs evaluation on one single dataloader"""
         void(*args, **kwargs)
         dataloader = self.trainer.accelerator.process_dataloader(self.current_dataloader)
-        dataloader_iter = enumerate(dataloader)
+        data_fetcher = DataFetcher()
+        data_fetcher.setup(dataloader)
+        dataloader_iter = enumerate(data_fetcher)
         dl_max_batches = self._max_batches[self.current_dataloader_idx]

         dl_outputs = self.epoch_loop.run(
@@ -86,7 +86,7 @@ class EvaluationEpochLoop(Loop):
         """
         void(dl_max_batches, num_dataloaders)

-        batch_idx, batch = next(dataloader_iter)
+        batch_idx, (batch, _) = next(dataloader_iter)

         if batch is None:
             raise StopIteration
@@ -17,7 +17,7 @@ import os
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
 from pathlib import Path
-from typing import Any, Callable, Dict, Optional, TextIO, Union
+from typing import Any, Callable, Dict, Generator, Iterable, Optional, TextIO, Union

 from pytorch_lightning.utilities import rank_zero_deprecation
 from pytorch_lightning.utilities.cloud_io import get_filesystem

@@ -78,7 +78,7 @@ class BaseProfiler(AbstractProfiler):
         self._stage: Optional[str] = None

     @contextmanager
-    def profile(self, action_name: str) -> None:
+    def profile(self, action_name: str) -> Generator:
         """
         Yields a context manager to encapsulate the scope of a profiled action.

@@ -96,7 +96,7 @@ class BaseProfiler(AbstractProfiler):
         finally:
             self.stop(action_name)

-    def profile_iterable(self, iterable, action_name: str) -> None:
+    def profile_iterable(self, iterable: Iterable, action_name: str) -> Generator:
         iterator = iter(iterable)
         while True:
             try:
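The annotation changes above only make the existing behaviour explicit: `profile` yields inside a `contextmanager`, and `profile_iterable` wraps an iterator. A small usage sketch; the `SimpleProfiler` choice and the action names are illustrative, not taken from this diff.

from pytorch_lightning.profiler import SimpleProfiler

profiler = SimpleProfiler()

# `profile` yields once, timing everything inside the `with` block and
# calling `stop(action_name)` in the `finally` clause shown above.
with profiler.profile("get_train_batch"):
    batch = [0, 1, 2]  # stand-in for fetching a real batch

# `profile_iterable` times each `next()` call on the wrapped iterator.
for item in profiler.profile_iterable(iter(range(3)), "toy_action"):
    pass
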
@@ -24,7 +24,7 @@ import pytorch_lightning as pl
 from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, rank_zero_deprecation, rank_zero_info, rank_zero_warn
 from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.imports import _fault_tolerant_enabled
+from pytorch_lightning.utilities.imports import _fault_tolerant_training
 from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS

 if _OMEGACONF_AVAILABLE:

@@ -348,7 +348,7 @@ class CheckpointConnector:
             "pytorch-lightning_version": pl.__version__,
             "state_dict": self._get_lightning_module_state_dict(),
         }
-        if _fault_tolerant_enabled():
+        if _fault_tolerant_training():
             checkpoint["loops"] = self._get_loops_state_dict()

         if not weights_only:

@@ -451,7 +451,7 @@ class CheckpointConnector:
     def _get_lightning_module_state_dict(self) -> Dict[str, torch.Tensor]:
         metrics = (
             [m for m in self.trainer.lightning_module.modules() if isinstance(m, Metric)]
-            if _fault_tolerant_enabled()
+            if _fault_tolerant_training()
             else []
         )
@@ -12,12 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Callable, Optional, Union
+from typing import Callable, Iterable, Optional, Union

 import pytorch_lightning as pl
-from pytorch_lightning.trainer.supporters import prefetch_iterator
 from pytorch_lightning.utilities import rank_zero_deprecation
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.fetching import DataFetcher
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS

@@ -26,6 +26,7 @@ class DataConnector:
     def __init__(self, trainer: "pl.Trainer", multiple_trainloader_mode: str = "max_size_cycle"):
         self.trainer = trainer
         self.multiple_trainloader_mode = multiple_trainloader_mode
+        self.data_fetcher: Optional[DataFetcher] = None

     def on_trainer_init(
         self,

@@ -59,10 +60,11 @@ class DataConnector:
         self.trainer.reload_dataloaders_every_n_epochs = reload_dataloaders_every_n_epochs
         self.trainer._is_data_prepared = False

-    def get_profiled_train_dataloader(self, train_dataloader):
-        profiled_dl = self.trainer.profiler.profile_iterable(
-            enumerate(prefetch_iterator(train_dataloader)), "get_train_batch"
-        )
+    def get_profiled_train_dataloader(self, train_dataloader) -> Iterable:
+        self.data_fetcher = DataFetcher()
+        self.data_fetcher.setup(train_dataloader)
+        prefetcher_iter = iter(self.data_fetcher)
+        profiled_dl = self.trainer.profiler.profile_iterable(enumerate(prefetcher_iter), "get_train_batch")
         return profiled_dl

     def prepare_data(self) -> None:
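The training loop then consumes the returned iterable as `(batch_idx, (batch, is_last))` pairs. A rough sketch of an equivalent composition outside the Trainer, under the assumption that the profiler and dataloader shown here stand in for the trainer-owned ones; none of these names come from this diff.

import torch
from torch.utils.data import DataLoader, TensorDataset

from pytorch_lightning.profiler import SimpleProfiler
from pytorch_lightning.utilities.fetching import DataFetcher

# Same composition as `get_profiled_train_dataloader` above:
# DataFetcher -> iter() -> enumerate() -> profiler.profile_iterable(...)
loader = DataLoader(TensorDataset(torch.arange(4).float()), batch_size=2)
fetcher = DataFetcher()
fetcher.setup(loader)
profiler = SimpleProfiler()
profiled_dl = profiler.profile_iterable(enumerate(iter(fetcher)), "get_train_batch")

for batch_idx, (batch, is_last) in profiled_dl:
    pass  # a consumer can use `is_last` to detect the end of the epoch
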
@@ -34,12 +34,13 @@ from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.auto_restart import (
     _capture_metadata_collate,
     CaptureIterableDataset,
+    CaptureMapDataset,
     FastForwardSampler,
 )
 from pytorch_lightning.utilities.data import has_iterable_dataset, has_len
 from pytorch_lightning.utilities.debugging import InternalDebugger
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.imports import _fault_tolerant_enabled
+from pytorch_lightning.utilities.imports import _fault_tolerant_training
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.seed import pl_worker_init_function

@@ -168,7 +169,7 @@ class TrainerDataLoadingMixin(ABC):
         if is_predicting:
             batch_sampler = IndexBatchSamplerWrapper(batch_sampler)

-        if _fault_tolerant_enabled():
+        if _fault_tolerant_training():
             fast_forward_sampler = batch_sampler = FastForwardSampler(batch_sampler)
             fast_forward_sampler.setup(dataloader_batch_size=1)

@@ -180,7 +181,7 @@ class TrainerDataLoadingMixin(ABC):
             "drop_last": False,
         }

-        if _fault_tolerant_enabled():
+        if _fault_tolerant_training():
             fast_forward_sampler = sampler = FastForwardSampler(sampler)
             fast_forward_sampler.setup(dataloader_batch_size=dataloader.batch_size)

@@ -246,14 +247,20 @@ class TrainerDataLoadingMixin(ABC):
                 f"`{dataloader_cls_name}(dataset, sampler=DistributedSampler(dataset))`."
             )

-        # wrap the `IterableDataset` into a `CaptureIterableDataset` to record sampler states.
-        if _fault_tolerant_enabled() and isinstance(dl_kwargs["dataset"], IterableDataset):
-            dl_kwargs["dataset"] = CaptureIterableDataset(dataset=dl_kwargs["dataset"])
-        if isinstance(dl_kwargs["dataset"], IterableDataset):
-            dl_kwargs["batch_sampler"] = None
-            dl_kwargs["sampler"] = None
+        if isinstance(dl_kwargs["dataset"], IterableDataset):
+            del dl_kwargs["sampler"]
+            del dl_kwargs["batch_sampler"]
+
+        if _fault_tolerant_training():
+            if isinstance(dl_kwargs["dataset"], IterableDataset):
+                # wrap the `IterableDataset` into a `CaptureIterableDataset` to record sampler states.
+                dl_kwargs["dataset"] = CaptureIterableDataset(dataset=dl_kwargs["dataset"])
+            elif len(dl_kwargs["dataset"]):
+                dl_kwargs["dataset"] = CaptureMapDataset(dataset=dl_kwargs["dataset"])
+            else:
+                raise MisconfigurationException(
+                    "This shouldn't happen, please open an issue on Lightning Github repository."
+                )

         return dl_kwargs

@@ -308,7 +315,7 @@ class TrainerDataLoadingMixin(ABC):
         apply_to_collection(self.train_dataloader, DataLoader, self.auto_add_worker_init_fn)

         # add collate_fn to collect metadata for fault tolerant training
-        if _fault_tolerant_enabled():
+        if _fault_tolerant_training():
             apply_to_collection(self.train_dataloader, DataLoader, self._add_sampler_metadata_collate)

         # wrap the sequence of train loaders to a CombinedLoader object for computing the num_training_batches
@@ -15,7 +15,7 @@
 from collections.abc import Iterable, Iterator, Mapping, Sequence
 from dataclasses import dataclass, field
 from functools import partial
-from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Union

 import torch
 from torch.utils.data import Dataset

@@ -30,7 +30,7 @@ from pytorch_lightning.utilities.auto_restart import (
 )
 from pytorch_lightning.utilities.data import get_len
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.imports import _fault_tolerant_enabled
+from pytorch_lightning.utilities.imports import _fault_tolerant_training


 class TensorRunningAccum:

@@ -375,7 +375,7 @@ class CombinedLoader:
             num_batches_processed: The number of batches processed so far, needed because the individual dataloaders
                 may have already prefetched more batches by the time a state dict is requested.
         """
-        if not _fault_tolerant_enabled():
+        if not _fault_tolerant_training():
             return DataLoaderDict()

         state_dict_fn = partial(self._state_dict_fn, num_batches_processed=num_batches_processed)

@@ -541,7 +541,7 @@ class CombinedLoaderIterator:

         def next_fn(iterator: Iterator):
             batch = next(iterator)
-            if not _fault_tolerant_enabled():
+            if not _fault_tolerant_training():
                 return batch
             # when fault tolerant is enabled, the iterator will return
             # `FastForwardSampler` state_dict metadata

@@ -592,25 +592,3 @@ def _nested_calc_num_data(data: Union[Mapping, Sequence], compute_func: Callable
         new_data.append(x)

     return compute_func(new_data)
-
-
-def prefetch_iterator(iterable: Iterable) -> Generator[Tuple[Any, bool], None, None]:
-    """
-    Returns an iterator that pre-fetches and caches the next item.
-    The values are passed through from the given iterable with an added boolean indicating if this is the last item.
-    See `https://stackoverflow.com/a/1630350 <https://stackoverflow.com/a/1630350>`_
-    """
-    it = iter(iterable)
-
-    try:
-        # the iterator may be empty from the beginning
-        last = next(it)
-    except StopIteration:
-        return
-
-    for val in it:
-        # yield last and has next
-        yield last, False
-        last = val
-    # yield last, no longer has next
-    yield last, True
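With `prefetch_iterator` deleted above, its `(value, is_last)` contract is now provided by `DataFetcher`; a rough before/after sketch of a call site, with an illustrative toy dataset.

from torch.utils.data import DataLoader, IterableDataset

from pytorch_lightning.utilities.fetching import DataFetcher


class IterDataset(IterableDataset):
    def __iter__(self):
        yield from (1, 2, 3)


loader = DataLoader(IterDataset())

# Before this commit (helper removed above):
#   for batch, is_last in prefetch_iterator(loader):
#       ...

# After this commit:
fetcher = DataFetcher()
fetcher.setup(loader)
for batch, is_last in fetcher:
    ...  # same (value, is_last) contract, with state-capture hooks when fault tolerance is enabled
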
@@ -77,7 +77,7 @@ from pytorch_lightning.utilities import (
 from pytorch_lightning.utilities.debugging import InternalDebugger
 from pytorch_lightning.utilities.distributed import distributed_available
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.imports import _fault_tolerant_enabled
+from pytorch_lightning.utilities.imports import _fault_tolerant_training
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.model_summary import ModelSummary, summarize
 from pytorch_lightning.utilities.seed import reset_seed

@@ -1344,7 +1344,7 @@ class Trainer(
         )

     def _on_exception(self):
-        if not _fault_tolerant_enabled():
+        if not _fault_tolerant_training():
             return
         # save a checkpoint for fault tolerant training. we don't use `log_dir` to minimize the chances of failure.
         file_path = os.path.join(self.default_root_dir, ".pl_auto_save.ckpt")
@@ -21,6 +21,7 @@ from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Tup
 from torch.utils.data import Dataset, get_worker_info, Sampler
 from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter, DataLoader, IterableDataset

+import pytorch_lightning as pl
 from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.enums import AutoRestartBatchKeys
 from pytorch_lightning.utilities.exceptions import MisconfigurationException

@@ -515,7 +516,10 @@ def _capture_metadata_collate(samples: List, dataset: Dataset, default_collate:


 def patch_dataloader_iterator(
-    dataloader: DataLoader, iterator: Iterator, prefetcher, num_batches_fetched: int = 0
+    dataloader: DataLoader,
+    iterator: Iterator,
+    data_fecher: "pl.utilities.fetching.DataFetcher",
+    num_batches_fetched: int = 0,
 ) -> None:
     assert isinstance(dataloader.dataset, (CaptureMapDataset, CaptureIterableDataset))

@@ -554,7 +558,7 @@ def patch_dataloader_iterator(
                     num_batches_fetched=num_batches_fetched,
                 )
             ]
-            prefetcher._store_dataloader_iter_state(it, state)
+            data_fecher._store_dataloader_iter_state(it, state)
             return batch

         return wrapper
@@ -29,10 +29,10 @@ from pytorch_lightning.utilities.auto_restart import (
     patch_dataloader_iterator,
 )
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.imports import _fault_tolerant_enabled
+from pytorch_lightning.utilities.imports import _fault_tolerant_training


-class AbstractFetcher(ABC):
+class AbstractDataFetcher(ABC):

     """
     This class is used to control batch fetching flow.

@@ -61,13 +61,22 @@ class AbstractFetcher(ABC):
         self.reset()

     def setup(self, dataloader: DataLoader, **kwargs) -> None:
-        if not isinstance(dataloader, (DataLoader, CombinedLoader)):
-            raise MisconfigurationException(
-                "The `DataFetcher` should be setup with an instance of a PyTorch ``DataLoader``."
-            )
+        self._add_capture_metadata_collate(dataloader)
         self.dataloader = dataloader
-        if isinstance(dataloader, DataLoader) and not isinstance(dataloader.collate_fn, partial):
-            _add_capture_metadata_collate(dataloader)
+
+    @staticmethod
+    def _add_capture_metadata_collate(dataloader: Iterable) -> None:
+        if not isinstance(dataloader, (DataLoader, CombinedLoader)):
+            return
+
+        if isinstance(dataloader, CombinedLoader):
+            dataloader = dataloader.loaders
+
+        def add_capture_metadata_collate(dataloader: DataLoader):
+            if not isinstance(dataloader.collate_fn, partial):
+                _add_capture_metadata_collate(dataloader)
+
+        apply_to_collection(dataloader, DataLoader, add_capture_metadata_collate)

     def add_batch(self, batch) -> None:
         self.batches.append(batch)

@@ -82,7 +91,7 @@ class AbstractFetcher(ABC):
             # cycle_iterator = iterator
             iterator = iterator._loader_iter

-        if isinstance(loader, DataLoader) and _fault_tolerant_enabled():
+        if isinstance(loader, DataLoader) and _fault_tolerant_training():
             loader._lightning_fetcher = self
             patch_dataloader_iterator(loader, iterator, self)

@@ -161,7 +170,7 @@ class AbstractFetcher(ABC):
         self.done: bool = False


-class LightningDataFetcher(AbstractFetcher):
+class DataFetcher(AbstractDataFetcher):

     """
     This class is used to control batch fetching flow.
@@ -103,9 +103,6 @@ else:
     _IPU_AVAILABLE = False


-def _fault_tolerant_enabled() -> bool:
-    """
-    EXPERIMENTAL
-    the `reset` function from `_MultiProcessingDataLoaderIter` was introduced in PyTorch 1.7 but we need to mock it.
-    """
+# experimental feature within PyTorch Lightning.
+def _fault_tolerant_training() -> bool:
     return _TORCH_GREATER_EQUAL_1_7 and int(os.getenv("PL_FAULT_TOLERANT_TRAINING", 0))
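The renamed helper remains a simple environment-variable gate; a small sketch of toggling it, mirroring the tests further down that patch `PL_FAULT_TOLERANT_TRAINING` (assumes PyTorch >= 1.7).

import os

from pytorch_lightning.utilities.imports import _fault_tolerant_training

# Fault-tolerant training is opt-in: the flag is read from the environment
# on every call, so it can be flipped at runtime.
os.environ["PL_FAULT_TOLERANT_TRAINING"] = "1"
assert _fault_tolerant_training()

os.environ["PL_FAULT_TOLERANT_TRAINING"] = "0"
assert not _fault_tolerant_training()
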
@@ -1111,6 +1111,7 @@ def test_current_score_when_nan(tmpdir, mode: str):
    model_checkpoint = ModelCheckpoint(dirpath=tmpdir, save_top_k=1, monitor="foo", mode=mode)
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=1,
        limit_val_batches=1,
        callbacks=[model_checkpoint],

@@ -1133,6 +1134,7 @@ def test_hparams_type(tmpdir, hparams_type):

    model_checkpoint = ModelCheckpoint(dirpath=tmpdir, save_top_k=1, monitor="foo")
    trainer = Trainer(
        max_epochs=1,
        default_root_dir=tmpdir,
        limit_train_batches=1,
        limit_val_batches=1,
@@ -28,7 +28,7 @@ import tests.helpers.utils as tutils
 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.trainer.connectors.logger_connector.result import _Sync, MetricSource, ResultCollection
-from pytorch_lightning.utilities.imports import _fault_tolerant_enabled, _TORCH_GREATER_EQUAL_1_7
+from pytorch_lightning.utilities.imports import _fault_tolerant_training, _TORCH_GREATER_EQUAL_1_7
 from tests.helpers import BoringModel
 from tests.helpers.runif import RunIf

@@ -384,7 +384,7 @@ def result_collection_reload(**kwargs):
     and final accumulation with Fault Tolerant Training is correct.
     """

-    if not _fault_tolerant_enabled():
+    if not _fault_tolerant_training():
         pytest.skip("Fault tolerant not available")

     num_processes = kwargs.get("gpus", 1)
@@ -379,6 +379,7 @@ def test_loop_state_on_exception(accumulate_grad_batches, stop_epoch, stop_batch
        pass

    ckpt_path = str(tmpdir / ".pl_auto_save.ckpt")
    assert os.path.exists(ckpt_path)
    checkpoint = torch.load(ckpt_path)

    optim_progress = trainer.fit_loop.epoch_loop.batch_loop.optim_progress
@@ -29,7 +29,6 @@ from pytorch_lightning.trainer.supporters import (
     CombinedLoader,
     CombinedLoaderIterator,
     CycleIterator,
-    prefetch_iterator,
     TensorRunningAccum,
 )
 from pytorch_lightning.utilities.apply_func import apply_to_collection

@@ -80,28 +79,6 @@ def test_none_length_cycle_iterator():
     assert item == 0


-def test_prefetch_iterator():
-    """Test the prefetch_iterator with PyTorch IterableDataset."""
-
-    class IterDataset(IterableDataset):
-        def __iter__(self):
-            yield 1
-            yield 2
-            yield 3
-
-    dataset = IterDataset()
-    iterator = prefetch_iterator(dataset)
-    assert list(iterator) == [(1, False), (2, False), (3, True)]
-
-    class EmptyIterDataset(IterableDataset):
-        def __iter__(self):
-            return iter([])
-
-    dataset = EmptyIterDataset()
-    iterator = prefetch_iterator(dataset)
-    assert list(iterator) == []
-
-
 @pytest.mark.parametrize(
     ["dataset_1", "dataset_2"],
     [
@@ -39,7 +39,7 @@ from pytorch_lightning.utilities.auto_restart import (
 )
 from pytorch_lightning.utilities.enums import AutoRestartBatchKeys
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.imports import _fault_tolerant_enabled
+from pytorch_lightning.utilities.imports import _fault_tolerant_training
 from tests.helpers.boring_model import BoringModel
 from tests.helpers.runif import RunIf

@@ -100,13 +100,6 @@ def _generate_state(base_seed, worker_id):
     return state


-@RunIf(min_torch="1.7.0")
-@pytest.mark.parametrize("env_setting,expected", [("0", False), ("1", True)])
-def test_fault_tolerant_enabled(env_setting, expected):
-    with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": env_setting}):
-        assert _fault_tolerant_enabled() == expected
-
-
 def test_fast_forward_getattr():
     dataset = range(15)
     sampler = SequentialSampler(dataset)

@@ -647,10 +640,9 @@ def test_fast_forward_sampler_with_distributed_sampler_and_iterative_dataset():


 @mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
-@RunIf(max_torch="1.6")
+@RunIf(max_torch="1.7")
 def test_fault_tolerant_not_supported():
-    with pytest.raises(MisconfigurationException, match="Restart is only supported with torch >= 1.7.0."):
-        _fault_tolerant_enabled()
+    assert not _fault_tolerant_training()


 def create_iterable_dataset(batch_size, num_workers, attr_name="iter_sampler", wrap: bool = True):
@@ -17,12 +17,12 @@ from torch.utils.data import DataLoader, IterableDataset

 from pytorch_lightning.trainer.supporters import CombinedLoader
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.fetching import LightningDataFetcher
+from pytorch_lightning.utilities.fetching import DataFetcher


 @pytest.mark.parametrize("use_combined_loader", [False, True])
 def test_prefetch_iterator(use_combined_loader):
-    """Test the LightningDataFetcher with PyTorch IterableDataset."""
+    """Test the DataFetcher with PyTorch IterableDataset."""

     class IterDataset(IterableDataset):
         def __iter__(self):

@@ -41,7 +41,7 @@ def test_prefetch_iterator(use_combined_loader):
     else:
         loader = DataLoader(IterDataset())
     expected = [(1, False), (2, False), (3, True)]
-    iterator = LightningDataFetcher(prefetch_batches=prefetch_batches)
+    iterator = DataFetcher(prefetch_batches=prefetch_batches)
     prefetch_batches += 1
     assert iterator.prefetch_batches == prefetch_batches
     iterator.setup(loader)

@@ -66,21 +66,14 @@ def test_prefetch_iterator(use_combined_loader):
             return iter([])

     dataloader = DataLoader(EmptyIterDataset())
-    iterator = LightningDataFetcher()
+    iterator = DataFetcher()
     iterator.setup(dataloader)
     assert list(iterator) == []


 def test_misconfiguration_error():

-    fetcher = LightningDataFetcher()
-    with pytest.raises(
-        MisconfigurationException,
-        match="The `DataFetcher` should be setup with an instance of a PyTorch ``DataLoader``.",
-    ):
-        fetcher.setup(range(10))
-
-    fetcher = LightningDataFetcher()
+    fetcher = DataFetcher()
     with pytest.raises(
         MisconfigurationException, match="The `dataloader_iter` isn't available outside the __iter__ context."
     ):