Do not prefetch when possible (#12101)
parent ed7ccca5df
commit 6309a59c3c
@@ -393,6 +393,9 @@ option when using sequential data.
 to ``limit_{mode}_batches``, if it is set to 1.0 it will run for the whole dataset, otherwise it will throw an exception.
 Here ``mode`` can be train/val/test/predict.
 
+When iterable datasets are used, Lightning will pre-fetch 1 batch (in addition to the current batch) so it can detect
+when the training will stop and run validation if necessary.
+
 .. testcode::
 
     # IterableDataset
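The ``testcode`` example itself is cut off in this view. A minimal sketch of the kind of snippet it introduces, assuming a simple unsized dataset (the class name and sizes are illustrative, not the documented example):

    import torch
    from torch.utils.data import DataLoader, IterableDataset


    class RandomIterableDataset(IterableDataset):
        """An unsized dataset: it defines ``__iter__`` but no ``__len__``."""

        def __init__(self, size: int, count: int) -> None:
            self.size = size
            self.count = count

        def __iter__(self):
            for _ in range(self.count):
                yield torch.randn(self.size)


    # Because the loader has no length, Lightning pre-fetches one extra batch
    # to detect the last batch and decide when to run validation.
    train_loader = DataLoader(RandomIterableDataset(32, 64), batch_size=4)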
@@ -81,6 +81,13 @@ class EvaluationLoop(DataLoaderLoop):
             raise RuntimeError("Dataloaders should be available.")
         return dataloaders
 
+    @property
+    def prefetch_batches(self) -> int:
+        batches = self.trainer.num_test_batches if self.trainer.testing else self.trainer.num_val_batches
+        is_unsized = batches[self.current_dataloader_idx] == float("inf")
+        inter_batch_parallelism = os.getenv("PL_INTER_BATCH_PARALLELISM", "0") == "1"
+        return 1 if is_unsized or inter_batch_parallelism else 0
+
     def connect(self, epoch_loop: EvaluationEpochLoop) -> None:  # type: ignore[override]
         """Connect the evaluation epoch loop with this loop."""
         self.epoch_loop = epoch_loop
@@ -121,7 +128,7 @@
         void(*args, **kwargs)
 
         data_fetcher_cls = _select_data_fetcher_type(self.trainer)
-        self._data_fetcher = data_fetcher_cls()
+        self._data_fetcher = data_fetcher_cls(prefetch_batches=self.prefetch_batches)
 
         # hook
         self._on_evaluation_model_eval()
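The ``prefetch_batches`` property added above (and its counterpart on ``FitLoop`` further down) encodes when prefetching is still required. A standalone sketch of that decision, using a hypothetical helper name:

    import os


    def needed_prefetch_batches(num_batches: float) -> int:
        # Illustrative helper mirroring the new `prefetch_batches` properties:
        # prefetch a single batch only when the number of batches is unknown
        # (an unsized/iterable dataloader) or when inter-batch parallelism is
        # explicitly enabled through the PL_INTER_BATCH_PARALLELISM env variable.
        is_unsized = num_batches == float("inf")
        inter_batch_parallelism = os.getenv("PL_INTER_BATCH_PARALLELISM", "0") == "1"
        return 1 if is_unsized or inter_batch_parallelism else 0


    assert needed_prefetch_batches(float("inf")) == 1  # unsized loader: prefetch 1
    # with the env flag unset, a sized loader yields 0: no prefetching at all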
@@ -85,6 +85,8 @@ class EvaluationEpochLoop(Loop):
         self._reload_dataloader_state_dict(data_fetcher)
         # creates the iterator inside the fetcher but returns `self`
         self._data_fetcher = cast(AbstractDataFetcher, iter(data_fetcher))
+        # add the previous `fetched` value to properly track `is_last_batch` with no prefetching
+        data_fetcher.fetched += self.batch_progress.current.ready
 
     def advance(  # type: ignore[override]
         self,
@@ -142,7 +142,9 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
 
     def on_run_start(self, data_fetcher: AbstractDataFetcher) -> None:  # type: ignore[override]
         self._reload_dataloader_state_dict(data_fetcher)
-        iter(data_fetcher)  # creates the iterator inside the fetcher
+        _ = iter(data_fetcher)  # creates the iterator inside the fetcher
+        # add the previous `fetched` value to properly track `is_last_batch` with no prefetching
+        data_fetcher.fetched += self.batch_progress.current.ready
 
     def advance(self, data_fetcher: AbstractDataFetcher) -> None:  # type: ignore[override]
         """Runs a single training batch.
@@ -149,6 +149,12 @@ class FitLoop(Loop[None]):
         restarting &= finished_before_on_train_end
         Loop.restarting.fset(self, restarting)  # call the parent setter
 
+    @property
+    def prefetch_batches(self) -> int:
+        is_unsized = self.trainer.num_training_batches == float("inf")
+        inter_batch_parallelism = os.getenv("PL_INTER_BATCH_PARALLELISM", "0") == "1"
+        return 1 if is_unsized or inter_batch_parallelism else 0
+
     @property
     def _skip_backward(self) -> bool:
         """Determines whether the loop will skip backward during automatic optimization."""
@@ -213,8 +219,9 @@
         """Calls the ``on_train_start`` hook."""
         # reset train dataloader and val dataloader
         self.trainer.reset_train_val_dataloaders(self.trainer.lightning_module)
 
         data_fetcher_cls = _select_data_fetcher(self.trainer)
-        self._data_fetcher = data_fetcher_cls()
+        self._data_fetcher = data_fetcher_cls(prefetch_batches=self.prefetch_batches)
 
         self._is_fresh_start_epoch = True
         self._results.to(device=self.trainer.lightning_module.device)
@@ -89,17 +89,13 @@ def has_iterable_dataset(dataloader: DataLoader) -> bool:
 
 def has_len(dataloader: Union[DataLoader, Iterable]) -> bool:
     """Checks if a given Dataloader has ``__len__`` method implemented i.e. if it is a finite dataloader or
-    infinite dataloader.
-
-    Raises:
-        ValueError:
-            If the length of Dataloader is 0, as it requires at least one batch
-    """
+    infinite dataloader."""
     try:
         # try getting the length
         if len(dataloader) == 0:
-            raise ValueError("`Dataloader` returned 0 length. Please make sure that it returns at least 1 batch")
+            rank_zero_warn(
+                f"`{dataloader.__class__.__name__}` returned 0 length. Please make sure this was your intention."
+            )
        has_len = True
    except TypeError:
        has_len = False
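With this change a zero-length dataloader produces a warning instead of a ``ValueError``. A small sketch of the new behaviour (the ``EmptyDataset`` class is a stand-in, not a helper from the Lightning test suite):

    import warnings

    from torch.utils.data import DataLoader, Dataset

    from pytorch_lightning.utilities.data import has_len


    class EmptyDataset(Dataset):
        # a stand-in for a dataset that, perhaps unintentionally, has no samples
        def __len__(self) -> int:
            return 0

        def __getitem__(self, idx):
            raise IndexError


    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        assert has_len(DataLoader(EmptyDataset()))  # still sized, but warns now
        assert any("returned 0 length" in str(w.message) for w in caught)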
@@ -122,30 +118,27 @@ def has_len_all_ranks(
     model: Union["pl.LightningModule", "pl.LightningDataModule"],
 ) -> bool:
     """Checks if a given Dataloader has ``__len__`` method implemented i.e. if it is a finite dataloader or
-    infinite dataloader.
-
-    Raises:
-        ValueError:
-            If the length of Dataloader is 0, as it requires at least one batch
-    """
+    infinite dataloader."""
     try:
-        total_length = training_type.reduce(torch.tensor(len(dataloader)).to(model.device), reduce_op="sum")
+        local_length = len(dataloader)
+        total_length = training_type.reduce(torch.tensor(local_length).to(model.device), reduce_op="sum")
 
         if total_length == 0:
-            raise MisconfigurationException(
-                "Total length of `Dataloader` across ranks is zero. Please make sure that it returns at least 1 batch."
+            rank_zero_warn(
+                f"Total length of `{dataloader.__class__.__name__}` across ranks is zero."
+                " Please make sure this was your intention."
             )
         if total_length > 0 and local_length == 0:
             if model.allow_zero_length_dataloader_with_multiple_devices:
                 rank_zero_warn(
-                    "Total length of `Dataloader` across ranks is zero, but local rank has zero length."
-                    " Please be cautious of uneven batch length."
+                    f"Total length of `{dataloader.__class__.__name__}` across ranks is zero, but local rank has zero"
+                    " length. Please be cautious of uneven batch length."
                 )
                 has_len = False
             else:
                 raise MisconfigurationException(
-                    "`Dataloader` within local rank has zero length. Please make sure that it returns at least 1 batch."
+                    f"`{dataloader.__class__.__name__}` within local rank has zero length."
+                    " Please make sure that it returns at least 1 batch."
                 )
         else:
             has_len = True
@@ -15,7 +15,7 @@
 from abc import ABC, abstractmethod
 from collections.abc import Iterable, Iterator
 from copy import deepcopy
-from typing import Any, Callable, List, Optional, Tuple
+from typing import Any, Callable, List, Optional, Sized, Tuple
 
 import torch
 from torch.utils.data.dataloader import DataLoader
@@ -30,6 +30,7 @@ from pytorch_lightning.utilities.auto_restart import (
     MergedIteratorState,
     patch_dataloader_iterator,
 )
+from pytorch_lightning.utilities.data import has_len
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _fault_tolerant_training
 
@@ -79,6 +80,8 @@ class AbstractDataFetcher(ABC):
     def setup(self, dataloader: Iterable, **kwargs: Any) -> None:
         self._add_capture_metadata_collate(dataloader)
         self._dataloader = dataloader
+        _patch_dataloader_get_iterators()
+        self._attach_data_fetcher()
 
     @property
     def dataloader(self) -> Iterable:
@@ -172,8 +175,6 @@ class AbstractDataFetcher(ABC):
 
     def __iter__(self) -> "AbstractDataFetcher":
         self.reset()
-        self._attach_data_fetcher()
-        _patch_dataloader_get_iterators()
         self.dataloader_iter = iter(self.dataloader)
         self._apply_patch()
         self.prefetching()
@@ -205,7 +206,7 @@ class DataFetcher(AbstractDataFetcher):
 
     Args:
         prefetch_batches: Number of batches to pre-fetch. Pre-fetching at least 1 batch is necessary to properly track
-            whether a batch is the last one (available with :attr:`self.done`).
+            whether a batch is the last one (available with :attr:`self.done`) under any training setup.
         store_on_device: Whether to store the pre-fetched batches on device.
     """
 
@@ -214,11 +215,13 @@
         self.store_on_device = store_on_device
         self.batch_to_device: Callable[[Any], Any] = _no_op_batch_to_device
         self.batches: List[Any] = []
+        self._has_len = False
 
     def setup(  # type: ignore[override]
         self, dataloader: Iterable, batch_to_device: Optional[Callable[[Any], Any]] = None
     ) -> None:
         super().setup(dataloader)
+        self._has_len = has_len(dataloader)
         if batch_to_device is not None:
             self.batch_to_device = batch_to_device
 
@@ -233,6 +236,9 @@
             try:
                 self._fetch_next_batch(iterator)
             except StopIteration:
+                # this would only happen when prefetch_batches > the number of batches available and makes
+                # `fetching_function` jump directly to the empty iterator case without trying to fetch again
+                self.done = True
                 break
 
     def fetching_function(self) -> Any:
@@ -266,6 +272,11 @@
         start_output = self.on_fetch_start()
         batch = next(iterator)
         self.fetched += 1
+        if not self.prefetch_batches and self._has_len:
+            # when we don't prefetch but the dataloader is sized, we use the length for `done`
+            dataloader = self.dataloader
+            assert isinstance(dataloader, Sized)  # `_has_len` is True
+            self.done = self.fetched >= len(dataloader)
         self.on_fetch_end(batch, start_output)
 
     def move_to_device(self, batch: Any) -> Any:
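The core of the change sits in ``_fetch_next_batch`` above: with a sized loader and no prefetching, ``done`` is derived from the known length instead of from a look-ahead batch. A stripped-down sketch of that bookkeeping (names are illustrative, not the Lightning API):

    from typing import Iterator, List, Sized


    class MiniFetcher:
        """Illustrative only: tracks `fetched`/`done` the way the diff describes."""

        def __init__(self, dataloader: Sized, prefetch_batches: int = 0) -> None:
            self.dataloader = dataloader
            self.prefetch_batches = prefetch_batches
            self.fetched = 0
            self.done = False

        def fetch_next(self, iterator: Iterator) -> int:
            batch = next(iterator)
            self.fetched += 1
            if not self.prefetch_batches:
                # no look-ahead batch: rely on the known length to flag the last batch
                self.done = self.fetched >= len(self.dataloader)
            return batch


    data: List[int] = [10, 20, 30]
    fetcher = MiniFetcher(data)
    it = iter(data)
    assert [fetcher.fetch_next(it) for _ in data] == data
    assert fetcher.done  # the third fetch was already known to be the last one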
@@ -360,7 +371,8 @@ class DataLoaderIterDataFetcher(AbstractDataFetcher):
         ...
     """
 
-    def __init__(self) -> None:
+    def __init__(self, prefetch_batches: int = 0) -> None:
+        # prefetch batches is not used for this class
         super().__init__()
         self.store_on_device = False
 
@@ -648,16 +648,12 @@ def test_loop_state_on_complete_run(n_optimizers, tmpdir):
                 "ready": n_epochs,
                 "started": n_epochs,
                 "processed": n_epochs,
-                # TODO: the following "-1" offset will be fixed by
-                # https://github.com/PyTorchLightning/pytorch-lightning/pull/8578
                 "completed": n_epochs - 1,
             },
             "current": {
                 "ready": n_epochs,
                 "started": n_epochs,
                 "processed": n_epochs,
-                # TODO: the following "-1" offset will be fixed by
-                # https://github.com/PyTorchLightning/pytorch-lightning/pull/8578
                 "completed": n_epochs - 1,
             },
         },
@@ -956,8 +952,6 @@ def test_fit_can_fail_during_validation(train_datasets, val_datasets, val_check_
     # totals are increased by 1 (the failed batch which never completed)
     expected = state_dict.copy()
 
-    # TODO: `is_last_batch` is not correct on reload, the next line should not be necessary
-    expected["epoch_loop.batch_progress"]["is_last_batch"] = val_check_interval == 1.0
     assert state_dict_after_restart["epoch_loop.batch_progress"] == expected["epoch_loop.batch_progress"]
 
     val_dl_progress = "epoch_loop.val_loop.dataloader_progress"
@@ -516,20 +516,16 @@ def test_mixing_of_dataloader_options(tmpdir, ckpt_path):
     assert len(trainer.test_dataloaders) == 1
 
 
-def test_error_on_zero_len_dataloader(tmpdir):
-    """Test that error is raised if a zero-length dataloader is defined."""
-
-    class CustomBoringModel(BoringModel):
-        def train_dataloader(self):
-            return DataLoader(RandomDataset(32, 0))
-
-    model = CustomBoringModel()
+def test_warning_on_zero_len_dataloader(tmpdir):
+    """Test that a warning is raised if a zero-length dataloader is defined."""
+    model = BoringModel()
     trainer = Trainer(
         default_root_dir=tmpdir,
         fast_dev_run=1,
     )
-    with pytest.raises(ValueError, match="returned 0 length. .* at least 1 batch"):
-        trainer.fit(model)
+    dataloader = DataLoader(RandomDataset(32, 0))
+    with pytest.warns(UserWarning, match="returned 0 length"):
+        trainer.fit(model, dataloader)
 
 
 @RunIf(skip_windows=True)
@@ -1452,7 +1452,7 @@ class RandomFaultTolerantDataset(RandomGetItemDataset):
 
 
 class RandomFaultTolerantSampler(RandomSampler):
-    def __init__(self, *args, seed: int = 0, generator=None, **kwargs):
+    def __init__(self, *args, seed: int = 0, **kwargs):
         generator = torch.Generator().manual_seed(seed)
         super().__init__(*args, generator=generator, **kwargs)
         self.counter = 0
@@ -1558,7 +1558,7 @@ def test_fault_tolerant_manual_mode(val_check_interval, train_dataset_cls, val_d
     seed_everything(42)
     model = TestModel(should_fail=True)
     trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, val_check_interval=val_check_interval)
-    with suppress(CustomException):
+    with pytest.raises(CustomException):
         trainer.fit(model)
     trainer.train_dataloader = None
     failed_batches = model.batches
@@ -93,7 +93,7 @@ def test_has_iterable_dataset():
 def test_has_len():
     assert has_len(DataLoader(RandomDataset(1, 1)))
 
-    with pytest.raises(ValueError, match="`Dataloader` returned 0 length."):
+    with pytest.warns(UserWarning, match="`DataLoader` returned 0 length."):
         assert has_len(DataLoader(RandomDataset(0, 0)))
 
     assert not has_len(DataLoader(RandomIterableDataset(1, 1)))
@@ -112,8 +112,8 @@ def test_has_len_all_rank():
     trainer = Trainer(fast_dev_run=True)
     model = BoringModel()
 
-    with pytest.raises(MisconfigurationException, match="Total length of `Dataloader` across ranks is zero."):
-        assert not has_len_all_ranks(DataLoader(RandomDataset(0, 0)), trainer.strategy, model)
+    with pytest.warns(UserWarning, match="Total length of `DataLoader` across ranks is zero."):
+        assert has_len_all_ranks(DataLoader(RandomDataset(0, 0)), trainer.strategy, model)
 
     assert has_len_all_ranks(DataLoader(RandomDataset(1, 1)), trainer.strategy, model)
@@ -18,7 +18,6 @@ from unittest import mock
 
 import pytest
 import torch
-from torch import tensor
 from torch.utils.data import DataLoader, Dataset, IterableDataset
 
 from pytorch_lightning import Callback, LightningDataModule, Trainer
@@ -30,57 +29,74 @@ from tests.helpers import BoringModel, RandomDataset
 from tests.helpers.runif import RunIf
 
 
+class IterDataset(IterableDataset):
+    def __iter__(self):
+        yield 1
+        yield 2
+        yield 3
+
+
+class SizedDataset(Dataset):
+    def __len__(self):
+        return 3
+
+    def __getitem__(self, idx):
+        return idx + 1
+
+
 @pytest.mark.parametrize("use_combined_loader", [False, True])
-def test_prefetch_iterator(use_combined_loader):
-    """Test the DataFetcher with PyTorch IterableDataset."""
+@pytest.mark.parametrize("dataset_cls", [IterDataset, SizedDataset])
+@pytest.mark.parametrize("prefetch_batches", list(range(5)))
+def test_prefetch_iterator(use_combined_loader, dataset_cls, prefetch_batches):
+    fetcher = DataFetcher(prefetch_batches=prefetch_batches)
+    assert fetcher.prefetch_batches == prefetch_batches
 
-    class IterDataset(IterableDataset):
-        def __iter__(self):
-            yield 1
-            yield 2
-            yield 3
+    if use_combined_loader:
+        loader = CombinedLoader([DataLoader(dataset_cls()), DataLoader(dataset_cls())])
+    else:
+        loader = DataLoader(dataset_cls())
+    fetcher.setup(loader)
 
-    for prefetch_batches in range(5):
-        iterator = DataFetcher(prefetch_batches=prefetch_batches)
-        assert iterator.prefetch_batches == prefetch_batches
+    def generate():
+        generated = [(fetcher.fetched, data, fetcher.done) for data in fetcher]
+        assert fetcher.fetched == 3
+        assert fetcher.done
+        return generated
 
-        if use_combined_loader:
-            loader = CombinedLoader([DataLoader(IterDataset()), DataLoader(IterDataset())])
-        else:
-            loader = DataLoader(IterDataset())
-        iterator.setup(loader)
+    # we can only know the last batch with sized iterables or when we prefetch
+    is_last_batch = [False, False, prefetch_batches > 0 or dataset_cls is SizedDataset]
+    fetched = list(range(prefetch_batches + 1, 4))
+    fetched += [3] * (3 - len(fetched))
+    batches = [[1, 1], [2, 2], [3, 3]] if use_combined_loader else [1, 2, 3]
+    expected = list(zip(fetched, batches, is_last_batch))
+    assert len(expected) == 3
 
-        def generate():
-            generated = [
-                (iterator.fetched, data, iterator.done) for i, data in enumerate(iterator, prefetch_batches + 1)
-            ]
-            assert iterator.fetched == 3
-            assert iterator.done
-            return generated
+    assert generate() == expected
+    # validate reset works properly.
+    assert generate() == expected
+    assert fetcher.fetched == 3
 
-        is_last_batch = [False, False, prefetch_batches > 0]
-        fetched = list(range(prefetch_batches + 1, 4))
-        fetched += [3] * (3 - len(fetched))
-        if use_combined_loader:
-            batches = [[tensor(1), tensor(1)], [tensor(2), tensor(2)], [tensor(3), tensor(3)]]
-        else:
-            batches = [1, 2, 3]
-        expected = list(zip(fetched, batches, is_last_batch))
-        assert len(expected) == 3
 
-        assert generate() == expected
-        # validate reset works properly.
-        assert generate() == expected
-        assert iterator.fetched == 3
+class EmptyIterDataset(IterableDataset):
+    def __iter__(self):
+        return iter([])
 
-    class EmptyIterDataset(IterableDataset):
-        def __iter__(self):
-            return iter([])
 
-    loader = DataLoader(EmptyIterDataset())
-    iterator = DataFetcher()
-    iterator.setup(loader)
-    assert not list(iterator)
+class EmptySizedDataset(Dataset):
+    def __len__(self):
+        return 0
+
+
+@pytest.mark.parametrize("dataset_cls", [EmptyIterDataset, EmptySizedDataset])
+@pytest.mark.parametrize("prefetch_batches", list(range(2)))
+def test_empty_prefetch_iterator(dataset_cls, prefetch_batches):
+    loader = DataLoader(dataset_cls())
+    fetcher = DataFetcher(prefetch_batches=prefetch_batches)
+    fetcher.setup(loader)
+
+    assert not fetcher.done
+    assert not list(fetcher)
+    assert fetcher.done
 
 
 def test_misconfiguration_error():
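The expected values in the rewritten test can be reproduced by hand. A small sketch computing them for one configuration (pure illustration, not part of the test suite):

    # For a 3-batch loader, reproduce the bookkeeping asserted by the test above.
    prefetch_batches = 2
    sized = False  # e.g. dataset_cls is IterDataset

    # the fetch counter observed as each batch is yielded
    fetched = list(range(prefetch_batches + 1, 4))
    fetched += [3] * (3 - len(fetched))  # -> [3, 3, 3]

    # the last batch is only known up front with prefetching or a sized loader
    is_last_batch = [False, False, prefetch_batches > 0 or sized]

    print(list(zip(fetched, [1, 2, 3], is_last_batch)))
    # [(3, 1, False), (3, 2, False), (3, 3, True)]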
@@ -188,7 +204,7 @@ def test_trainer_num_prefetch_batches(tmpdir):
         def on_train_epoch_end(self, trainer, lightning_module):
             fetcher = trainer.fit_loop._data_fetcher
             assert isinstance(fetcher, InterBatchParallelDataFetcher if self._check_inter_batch else DataFetcher)
-            assert fetcher.prefetch_batches == 1
+            assert fetcher.prefetch_batches == int(self._check_inter_batch)
 
     trainer_kwargs = dict(
         default_root_dir=tmpdir,
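The inter-batch-parallelism branch is toggled purely through an environment variable. A minimal sketch of how a run might opt into it (the model is a placeholder, not part of this test):

    import os

    from pytorch_lightning import Trainer

    # Opt in before `fit()` selects the data fetcher; with the flag unset (the
    # default "0"), sized dataloaders are no longer prefetched at all.
    os.environ["PL_INTER_BATCH_PARALLELISM"] = "1"

    trainer = Trainer(fast_dev_run=True)
    # trainer.fit(model)  # `model` is a placeholder LightningModule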
@@ -269,14 +285,19 @@ def test_fetching_dataloader_iter_opt(automatic_optimization, tmpdir):
 @RunIf(min_torch="1.8.0")
 def test_fetching_dataloader_iter_running_stages(fn, tmpdir):
     class TestModel(BoringModel):
-        def validation_step(self, dataloader_iter, batch_idx):
-            assert isinstance(self.trainer.validate_loop._data_fetcher, DataLoaderIterDataFetcher)
+        def fetch(self, data_fetcher, dataloader_iter, batch_idx):
+            assert isinstance(data_fetcher, DataLoaderIterDataFetcher)
+            assert data_fetcher.fetched == batch_idx
             batch = next(dataloader_iter)
+            assert data_fetcher.fetched == batch_idx + 1
+            return batch
+
+        def validation_step(self, dataloader_iter, batch_idx):
+            batch = self.fetch(self.trainer.validate_loop._data_fetcher, dataloader_iter, batch_idx)
             return super().validation_step(batch, batch_idx)
 
         def test_step(self, dataloader_iter, batch_idx):
-            assert isinstance(self.trainer.test_loop._data_fetcher, DataLoaderIterDataFetcher)
-            batch = next(dataloader_iter)
+            batch = self.fetch(self.trainer.test_loop._data_fetcher, dataloader_iter, batch_idx)
             return super().test_step(batch, batch_idx)
 
     model = TestModel()