From a967b6eba0556943ad1d7c2a8dc41f1da4f68b2d Mon Sep 17 00:00:00 2001
From: Gili Tzabari
Date: Fri, 29 Oct 2021 12:29:44 -0400
Subject: [PATCH] del iterator on_run_end() (#9915)

---
 CHANGELOG.md                                  |  2 +-
 .../loops/dataloader/evaluation_loop.py       |  4 +-
 .../loops/epoch/training_epoch_loop.py        |  2 -
 pytorch_lightning/loops/fit_loop.py           |  4 +-
 pytorch_lightning/trainer/supporters.py       | 15 +++++++-
 pytorch_lightning/utilities/fetching.py       |  6 ++-
 tests/loops/test_evaluation_loop.py           |  2 +-
 tests/loops/test_loops.py                     | 37 ++++++++++++++++++-
 tests/loops/test_training_loop.py             |  1 -
 9 files changed, 63 insertions(+), 10 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 8d85445531..67516b8294 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -595,7 +595,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed bug where the training step output needed to be `deepcopy`-ed ([#9349](https://github.com/PyTorchLightning/pytorch-lightning/pull/9349))
 
-- Fixed freeing data iterators in loop `on_run_end` ([#9386](https://github.com/PyTorchLightning/pytorch-lightning/pull/9386))
+- Fixed freeing data iterators in loop `on_run_end` ([#9386](https://github.com/PyTorchLightning/pytorch-lightning/pull/9386)) ([#9915](https://github.com/PyTorchLightning/pytorch-lightning/pull/9915))
 
 - Fixed `BasePredictionWriter` not returning the batch_indices in a non-distributed setting ([#9432](https://github.com/PyTorchLightning/pytorch-lightning/pull/9432))
diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py
index c2a1a5e786..6140bd60d6 100644
--- a/pytorch_lightning/loops/dataloader/evaluation_loop.py
+++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py
@@ -101,7 +101,9 @@ class EvaluationLoop(DataLoaderLoop):
         dataloader_idx: int = self.current_dataloader_idx
         dataloader = self.trainer.training_type_plugin.process_dataloader(self.current_dataloader)
-        dataloader = self.trainer._data_connector.get_profiled_dataloader(dataloader, dataloader_idx=dataloader_idx)
+        self.data_fetcher = dataloader = self.trainer._data_connector.get_profiled_dataloader(
+            dataloader, dataloader_idx=dataloader_idx
+        )
         dl_max_batches = self._max_batches[dataloader_idx]
         dl_outputs = self.epoch_loop.run(dataloader, dataloader_idx, dl_max_batches, self.num_dataloaders)
diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py
index dbe163aa33..21d89a8be8 100644
--- a/pytorch_lightning/loops/epoch/training_epoch_loop.py
+++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -306,8 +306,6 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
         if self._num_ready_batches_reached():
             self.update_lr_schedulers("epoch", update_plateau_schedulers=True)
 
-        self._dataloader_iter = None
-
         # if fault tolerant is enabled and process has been notified, exit.
         self.trainer._exit_gracefully_on_signal()
diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py
index 024ff36a6f..df6634c963 100644
--- a/pytorch_lightning/loops/fit_loop.py
+++ b/pytorch_lightning/loops/fit_loop.py
@@ -209,7 +209,9 @@ class FitLoop(Loop):
             self.trainer.reset_train_dataloader(model)
         self._is_fresh_start_epoch = False
 
-        if callable(getattr(self.trainer.train_dataloader.sampler, "set_epoch", None)):
+        if self.trainer.train_dataloader is not None and callable(
+            getattr(self.trainer.train_dataloader.sampler, "set_epoch", None)
+        ):
             # set seed for distributed sampler (enables shuffling for each epoch)
             self.trainer.train_dataloader.sampler.set_epoch(self.current_epoch)
diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py
index 87e5f9f4f7..816f4da38f 100644
--- a/pytorch_lightning/trainer/supporters.py
+++ b/pytorch_lightning/trainer/supporters.py
@@ -19,7 +19,7 @@ from typing import Any, Callable, Dict, List, Optional, Union
 
 import torch
 from torch.utils.data import Dataset
-from torch.utils.data.dataloader import _BaseDataLoaderIter, DataLoader
+from torch.utils.data.dataloader import _BaseDataLoaderIter, _MultiProcessingDataLoaderIter, DataLoader
 from torch.utils.data.dataset import IterableDataset
 
 from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections
@@ -491,6 +491,19 @@ class CombinedLoader:
     def __len__(self) -> int:
         return self._calc_num_batches(self.loaders)
 
+    @staticmethod
+    def _shutdown_workers_and_reset_iterator(dataloader) -> None:
+        if hasattr(dataloader, "_iterator") and isinstance(dataloader._iterator, _MultiProcessingDataLoaderIter):
+            dataloader._iterator._shutdown_workers()
+        dataloader._iterator = None
+
+    def reset(self):
+        if self._iterator:
+            self._iterator._loader_iters = None
+        if self.loaders is not None:
+            apply_to_collection(self.loaders, DataLoader, self._shutdown_workers_and_reset_iterator)
+        self._iterator = None
+
 
 class CombinedLoaderIterator:
     """Custom Iterator returning data from multple loaders, and allows sampling in parallel."""
diff --git a/pytorch_lightning/utilities/fetching.py b/pytorch_lightning/utilities/fetching.py
index 689c2bff8e..fd9baf3e9c 100644
--- a/pytorch_lightning/utilities/fetching.py
+++ b/pytorch_lightning/utilities/fetching.py
@@ -204,9 +204,13 @@ class AbstractDataFetcher(ABC):
 
     def reset(self) -> None:
         self.batches: List = []
-        self.dataloader: Optional[Iterable]
         self.fetched: int = 0
         self.done: bool = False
+        if isinstance(self.dataloader, CombinedLoader):
+            self.dataloader.reset()
+        if isinstance(self.dataloader, DataLoader):
+            CombinedLoader._shutdown_workers_and_reset_iterator(self.dataloader)
+        self.dataloader_iter = None
 
     def teardown(self) -> None:
         self.reset()
diff --git a/tests/loops/test_evaluation_loop.py b/tests/loops/test_evaluation_loop.py
index 2b67dec18d..d6b2c15553 100644
--- a/tests/loops/test_evaluation_loop.py
+++ b/tests/loops/test_evaluation_loop.py
@@ -14,7 +14,7 @@ from unittest import mock
 
 import torch
-from torch.utils.data import DataLoader
+from torch.utils.data.dataloader import DataLoader
 
 from pytorch_lightning import Trainer
 from pytorch_lightning.loops import EvaluationEpochLoop
diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py
index a1efa838e9..dd390ab493 100644
--- a/tests/loops/test_loops.py
+++ b/tests/loops/test_loops.py
@@ -20,7 +20,7 @@ from unittest.mock import ANY
 
 import pytest
 import torch
-from torch.utils.data import DataLoader
+from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter, DataLoader
 
 from pl_examples.bug_report_model import RandomDataset
 from pytorch_lightning import LightningModule, Trainer
@@ -909,3 +909,38 @@ def test_fit_can_fail_during_validation(train_datasets, val_datasets, val_check_
     expected[val_batch_progress]["total"]["ready"] += 1
     expected[val_batch_progress]["total"]["started"] += 1
     assert state_dict_after_restart[val_batch_progress] == expected[val_batch_progress]
+
+
+@RunIf(min_torch="1.8.0")
+@pytest.mark.parametrize("persistent_workers", (True, False))
+def test_workers_are_shutdown(tmpdir, persistent_workers):
+    # `num_workers == 1` uses `_MultiProcessingDataLoaderIter`
+    # `persistent_workers` makes sure `self._iterator` gets set on the `DataLoader` instance
+
+    class _TestMultiProcessingDataLoaderIter(_MultiProcessingDataLoaderIter):
+        def __init__(self, *args, dataloader: DataLoader, **kwargs):
+            super().__init__(*args, **kwargs)
+            self.dataloader = dataloader
+
+        def _shutdown_workers(self):
+            setattr(self.dataloader, "has_shutdown_workers", True)
+            super()._shutdown_workers()
+
+    class TestDataLoader(DataLoader):
+        def _get_iterator(self):
+            if self.num_workers == 0:
+                return super()._get_iterator()
+            else:
+                self.check_worker_number_rationality()
+                return _TestMultiProcessingDataLoaderIter(self, dataloader=self)
+
+    train_dataloader = TestDataLoader(RandomDataset(32, 64), num_workers=1, persistent_workers=persistent_workers)
+    val_dataloader = TestDataLoader(RandomDataset(32, 64), num_workers=1, persistent_workers=persistent_workers)
+
+    model = BoringModel()
+    trainer = Trainer(default_root_dir=tmpdir, limit_train_batches=2, limit_val_batches=2, max_epochs=2)
+    trainer.fit(model, train_dataloader, val_dataloader)
+    assert train_dataloader.has_shutdown_workers
+    assert val_dataloader.has_shutdown_workers
+    assert train_dataloader._iterator is None
+    assert val_dataloader._iterator is None
diff --git a/tests/loops/test_training_loop.py b/tests/loops/test_training_loop.py
index ebfe0d4762..86801f5626 100644
--- a/tests/loops/test_training_loop.py
+++ b/tests/loops/test_training_loop.py
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import pytest
 import torch
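
Reviewer note: the sketch below (not part of the patch) illustrates the worker-shutdown pattern that the new
`CombinedLoader._shutdown_workers_and_reset_iterator` helper relies on. It assumes a plain `torch.utils.data.DataLoader`
with `num_workers > 0` on a recent PyTorch (>= 1.7 for `persistent_workers`), so the loader's cached iterator is a
`_MultiProcessingDataLoaderIter` whose private `_shutdown_workers()` joins the worker processes. The free function
`shutdown_workers_and_reset_iterator` is a hypothetical stand-in for the patched method, shown only to make the
behaviour concrete.

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter


    def shutdown_workers_and_reset_iterator(dataloader: DataLoader) -> None:
        """Illustrative only: mirrors CombinedLoader._shutdown_workers_and_reset_iterator from the patch."""
        iterator = getattr(dataloader, "_iterator", None)
        if isinstance(iterator, _MultiProcessingDataLoaderIter):
            # join and terminate the DataLoader worker processes (private PyTorch API)
            iterator._shutdown_workers()
        # drop the cached iterator so the next iteration starts from a clean state
        dataloader._iterator = None


    if __name__ == "__main__":
        loader = DataLoader(TensorDataset(torch.randn(8, 2)), num_workers=1, persistent_workers=True)
        for _ in loader:  # one full pass; the persistent iterator stays cached on `loader._iterator`
            pass
        assert loader._iterator is not None
        shutdown_workers_and_reset_iterator(loader)
        assert loader._iterator is None  # the next `iter(loader)` spawns fresh workers

Resetting eagerly inside the loop/fetcher `reset()` (rather than waiting for garbage collection) keeps worker
processes from outliving an epoch when the iterator was never exhausted, which is what the new
`test_workers_are_shutdown` test asserts.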