del iterator on_run_end() (#9915)
parent e4eb61d812
commit a967b6eba0
@@ -595,7 +595,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed bug where the training step output needed to be `deepcopy`-ed ([#9349](https://github.com/PyTorchLightning/pytorch-lightning/pull/9349))
-- Fixed freeing data iterators in loop `on_run_end` ([#9386](https://github.com/PyTorchLightning/pytorch-lightning/pull/9386))
+- Fixed freeing data iterators in loop `on_run_end` ([#9386](https://github.com/PyTorchLightning/pytorch-lightning/pull/9386)) ([#9915](https://github.com/PyTorchLightning/pytorch-lightning/pull/9915))
 - Fixed `BasePredictionWriter` not returning the batch_indices in a non-distributed setting ([#9432](https://github.com/PyTorchLightning/pytorch-lightning/pull/9432))
@@ -101,7 +101,9 @@ class EvaluationLoop(DataLoaderLoop):
         dataloader_idx: int = self.current_dataloader_idx
         dataloader = self.trainer.training_type_plugin.process_dataloader(self.current_dataloader)
-        dataloader = self.trainer._data_connector.get_profiled_dataloader(dataloader, dataloader_idx=dataloader_idx)
+        self.data_fetcher = dataloader = self.trainer._data_connector.get_profiled_dataloader(
+            dataloader, dataloader_idx=dataloader_idx
+        )
         dl_max_batches = self._max_batches[dataloader_idx]

         dl_outputs = self.epoch_loop.run(dataloader, dataloader_idx, dl_max_batches, self.num_dataloaders)
@@ -306,8 +306,6 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
         if self._num_ready_batches_reached():
             self.update_lr_schedulers("epoch", update_plateau_schedulers=True)

-        self._dataloader_iter = None
-
         # if fault tolerant is enabled and process has been notified, exit.
         self.trainer._exit_gracefully_on_signal()
@@ -209,7 +209,9 @@ class FitLoop(Loop):
             self.trainer.reset_train_dataloader(model)
         self._is_fresh_start_epoch = False

-        if callable(getattr(self.trainer.train_dataloader.sampler, "set_epoch", None)):
+        if self.trainer.train_dataloader is not None and callable(
+            getattr(self.trainer.train_dataloader.sampler, "set_epoch", None)
+        ):
             # set seed for distributed sampler (enables shuffling for each epoch)
             self.trainer.train_dataloader.sampler.set_epoch(self.current_epoch)
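For context on the guard above: `set_epoch` exists on samplers such as `DistributedSampler`, and it has to be called before every epoch so that the shuffle order actually changes between epochs. A minimal sketch of the same pattern outside of Lightning (the dataset and loop below are illustrative, not part of this commit):

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

# DistributedSampler shuffles deterministically per epoch, so it must be told
# which epoch is starting before the data is iterated.
dataset = TensorDataset(torch.arange(8))
sampler = DistributedSampler(dataset, num_replicas=1, rank=0, shuffle=True)
loader = DataLoader(dataset, sampler=sampler, batch_size=2)

for epoch in range(2):
    if callable(getattr(loader.sampler, "set_epoch", None)):
        loader.sampler.set_epoch(epoch)  # same guard as in FitLoop above
    for batch in loader:
        pass  # the training step would run here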
@@ -19,7 +19,7 @@ from typing import Any, Callable, Dict, List, Optional, Union

 import torch
 from torch.utils.data import Dataset
-from torch.utils.data.dataloader import _BaseDataLoaderIter, DataLoader
+from torch.utils.data.dataloader import _BaseDataLoaderIter, _MultiProcessingDataLoaderIter, DataLoader
 from torch.utils.data.dataset import IterableDataset

 from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections
@@ -491,6 +491,19 @@ class CombinedLoader:
     def __len__(self) -> int:
         return self._calc_num_batches(self.loaders)

+    @staticmethod
+    def _shutdown_workers_and_reset_iterator(dataloader) -> None:
+        if hasattr(dataloader, "_iterator") and isinstance(dataloader._iterator, _MultiProcessingDataLoaderIter):
+            dataloader._iterator._shutdown_workers()
+        dataloader._iterator = None
+
+    def reset(self):
+        if self._iterator:
+            self._iterator._loader_iters = None
+        if self.loaders is not None:
+            apply_to_collection(self.loaders, DataLoader, self._shutdown_workers_and_reset_iterator)
+        self._iterator = None
+

 class CombinedLoaderIterator:
     """Custom Iterator returning data from multple loaders, and allows sampling in parallel."""
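The new helper leans on two private PyTorch details visible in the hunk above: `DataLoader._iterator`, which only holds a live iterator when `persistent_workers=True`, and `_MultiProcessingDataLoaderIter._shutdown_workers()`. A hedged, standalone sketch of the same cleanup applied to a single `DataLoader` (the dataset class and function name are illustrative, not from the commit):

import torch
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter


class _RandomDataset(Dataset):
    # Tiny stand-in dataset, illustrative only.
    def __len__(self) -> int:
        return 16

    def __getitem__(self, index: int) -> torch.Tensor:
        return torch.rand(4)


def shutdown_workers_and_reset_iterator(dataloader: DataLoader) -> None:
    # Mirrors CombinedLoader._shutdown_workers_and_reset_iterator above:
    # only a multiprocessing iterator owns worker processes that need shutting down.
    if hasattr(dataloader, "_iterator") and isinstance(dataloader._iterator, _MultiProcessingDataLoaderIter):
        dataloader._iterator._shutdown_workers()
    dataloader._iterator = None


if __name__ == "__main__":
    loader = DataLoader(_RandomDataset(), num_workers=1, persistent_workers=True)
    for _ in loader:  # the first pass spawns workers and keeps them alive
        pass
    shutdown_workers_and_reset_iterator(loader)
    assert loader._iterator is None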
@@ -204,9 +204,13 @@ class AbstractDataFetcher(ABC):

     def reset(self) -> None:
         self.batches: List = []
-        self.dataloader: Optional[Iterable]
         self.fetched: int = 0
         self.done: bool = False
+        if isinstance(self.dataloader, CombinedLoader):
+            self.dataloader.reset()
+        if isinstance(self.dataloader, DataLoader):
+            CombinedLoader._shutdown_workers_and_reset_iterator(self.dataloader)
+        self.dataloader_iter = None

     def teardown(self) -> None:
         self.reset()
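The fetcher's `reset` dispatches cleanup on the type of the wrapped loader: a `CombinedLoader` resets all of its inner loaders, while a bare `DataLoader` goes through the static helper added in the previous hunk. The same dispatch written as a free function, purely for illustration (the function name is not part of the commit):

from typing import Any

from torch.utils.data import DataLoader

from pytorch_lightning.trainer.supporters import CombinedLoader


def release_dataloader(dataloader: Any) -> None:
    # Illustrative mirror of the cleanup order used in AbstractDataFetcher.reset() above.
    if isinstance(dataloader, CombinedLoader):
        dataloader.reset()  # shuts down workers of every wrapped DataLoader
    elif isinstance(dataloader, DataLoader):
        CombinedLoader._shutdown_workers_and_reset_iterator(dataloader)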
@@ -14,7 +14,7 @@
 from unittest import mock

 import torch
-from torch.utils.data import DataLoader
+from torch.utils.data.dataloader import DataLoader

 from pytorch_lightning import Trainer
 from pytorch_lightning.loops import EvaluationEpochLoop
@@ -20,7 +20,7 @@ from unittest.mock import ANY

 import pytest
 import torch
-from torch.utils.data import DataLoader
+from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter, DataLoader

 from pl_examples.bug_report_model import RandomDataset
 from pytorch_lightning import LightningModule, Trainer
@@ -909,3 +909,38 @@ def test_fit_can_fail_during_validation(train_datasets, val_datasets, val_check_
     expected[val_batch_progress]["total"]["ready"] += 1
     expected[val_batch_progress]["total"]["started"] += 1
     assert state_dict_after_restart[val_batch_progress] == expected[val_batch_progress]
+
+
+@RunIf(min_torch="1.8.0")
+@pytest.mark.parametrize("persistent_workers", (True, False))
+def test_workers_are_shutdown(tmpdir, persistent_workers):
+    # `num_workers == 1` uses `_MultiProcessingDataLoaderIter`
+    # `persistent_workers` makes sure `self._iterator` gets set on the `DataLoader` instance
+
+    class _TestMultiProcessingDataLoaderIter(_MultiProcessingDataLoaderIter):
+        def __init__(self, *args, dataloader: DataLoader, **kwargs):
+            super().__init__(*args, **kwargs)
+            self.dataloader = dataloader
+
+        def _shutdown_workers(self):
+            setattr(self.dataloader, "has_shutdown_workers", True)
+            super()._shutdown_workers()
+
+    class TestDataLoader(DataLoader):
+        def _get_iterator(self):
+            if self.num_workers == 0:
+                return super()._get_iterator()
+            else:
+                self.check_worker_number_rationality()
+                return _TestMultiProcessingDataLoaderIter(self, dataloader=self)
+
+    train_dataloader = TestDataLoader(RandomDataset(32, 64), num_workers=1, persistent_workers=persistent_workers)
+    val_dataloader = TestDataLoader(RandomDataset(32, 64), num_workers=1, persistent_workers=persistent_workers)
+
+    model = BoringModel()
+    trainer = Trainer(default_root_dir=tmpdir, limit_train_batches=2, limit_val_batches=2, max_epochs=2)
+    trainer.fit(model, train_dataloader, val_dataloader)
+    assert train_dataloader.has_shutdown_workers
+    assert val_dataloader.has_shutdown_workers
+    assert train_dataloader._iterator is None
+    assert val_dataloader._iterator is None
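The opening comments of the test above are the crux: `persistent_workers=True` is what makes a `DataLoader` keep its multiprocessing iterator on `self._iterator`, which is exactly the attribute the new teardown logic clears. A small plain-PyTorch check of that behavior (illustrative, not part of the commit):

import torch
from torch.utils.data import DataLoader, TensorDataset

if __name__ == "__main__":
    dataset = TensorDataset(torch.randn(8, 4))

    persistent = DataLoader(dataset, num_workers=1, persistent_workers=True)
    for _ in persistent:
        pass
    assert persistent._iterator is not None  # workers kept alive for the next epoch

    plain = DataLoader(dataset, num_workers=1, persistent_workers=False)
    for _ in plain:
        pass
    assert plain._iterator is None  # the iterator was never stored on the loader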
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.

 import pytest
 import torch