From 20ebb5ccc4b460d140cff2cc5752c93e8142faf9 Mon Sep 17 00:00:00 2001
From: ananthsub
Date: Tue, 14 Sep 2021 23:47:36 -0700
Subject: [PATCH] Reset datafetcher references in teardown (#9387)

* Free references to data fetcher in data connector teardown
---
 CHANGELOG.md                                  |  3 +++
 .../trainer/connectors/data_connector.py      |  4 ++++
 tests/utilities/test_fetching.py              | 17 ++++++++++++++---
 3 files changed, 21 insertions(+), 3 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index c88f7835c8..6c09c2f321 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -360,6 +360,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `replace_sampler` missing the batch size under specific conditions ([#9367](https://github.com/PyTorchLightning/pytorch-lightning/pull/9367))
 
 
+- Fixed freeing datafetchers during teardown ([#9387](https://github.com/PyTorchLightning/pytorch-lightning/pull/9387))
+
+
 - Fixed bug where the training step output needed to be `deepcopy`-ed ([#9349](https://github.com/PyTorchLightning/pytorch-lightning/pull/9349))
 
 
diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py
index 5673506a64..589906d2bb 100644
--- a/pytorch_lightning/trainer/connectors/data_connector.py
+++ b/pytorch_lightning/trainer/connectors/data_connector.py
@@ -248,12 +248,16 @@ class DataConnector:
     def teardown(self) -> None:
         if self.train_data_fetcher:
             self.train_data_fetcher.teardown()
+            self.train_data_fetcher = None
         if self.validate_data_fetcher:
             self.validate_data_fetcher.teardown()
+            self.validate_data_fetcher = None
         if self.test_data_fetcher:
             self.test_data_fetcher.teardown()
+            self.test_data_fetcher = None
         if self.sanity_check_data_fetcher:
             self.sanity_check_data_fetcher.teardown()
+            self.sanity_check_data_fetcher = None
 
 
 class _PatchDataLoader:
diff --git a/tests/utilities/test_fetching.py b/tests/utilities/test_fetching.py
index 86d04af9d2..420f580260 100644
--- a/tests/utilities/test_fetching.py
+++ b/tests/utilities/test_fetching.py
@@ -21,7 +21,7 @@ import torch
 from torch import tensor
 from torch.utils.data import DataLoader, Dataset, IterableDataset
 
-from pytorch_lightning import Trainer
+from pytorch_lightning import Callback, Trainer
 from pytorch_lightning.trainer.supporters import CombinedLoader
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.fetching import DataFetcher, DataLoaderIterDataFetcher, InterBatchParallelDataFetcher
@@ -179,6 +179,16 @@ def test_trainer_num_prefetch_batches(tmpdir):
 
     model = RecommenderModel()
 
+    class AssertFetcher(Callback):
+        def __init__(self, check_inter_batch: bool):
+            self._check_inter_batch = check_inter_batch
+
+        def on_train_epoch_end(self, trainer, lightning_module):
+            if self._check_inter_batch:
+                assert isinstance(trainer.data_connector.train_data_fetcher, InterBatchParallelDataFetcher)
+            else:
+                assert isinstance(trainer.data_connector.train_data_fetcher, DataFetcher)
+
     trainer_kwargs = dict(
         default_root_dir=tmpdir,
         max_epochs=1,
@@ -186,6 +196,7 @@ def test_trainer_num_prefetch_batches(tmpdir):
         limit_train_batches=4,
         limit_val_batches=0,
         num_sanity_val_steps=0,
+        callbacks=[AssertFetcher(check_inter_batch=True)],
     )
 
     with mock.patch.dict(os.environ, {"PL_INTER_BATCH_PARALLELISM": "1"}):
@@ -193,16 +204,16 @@
         t0 = time()
         trainer = Trainer(**trainer_kwargs)
         trainer.fit(model)
         t1 = time()
-        assert isinstance(trainer.data_connector.train_data_fetcher, InterBatchParallelDataFetcher)
         global_step = trainer.global_step
 
     torch.cuda.synchronize()
 
+    trainer_kwargs["callbacks"] = [AssertFetcher(check_inter_batch=False)]
+
     t2 = time()
     trainer = Trainer(**trainer_kwargs)
     trainer.fit(model)
     t3 = time()
-    assert isinstance(trainer.data_connector.train_data_fetcher, DataFetcher)
     assert global_step == trainer.global_step == 4
     ratio = (t3 - t2) / (t1 - t0)
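
Not part of the patch: a minimal, self-contained sketch of why the teardown above also drops each fetcher reference after calling its teardown. The FakeFetcher and FakeConnector names below are hypothetical stand-ins, not the real pytorch_lightning classes; the point is only that once the connector stops referencing a torn-down fetcher, the fetcher and whatever prefetched data it still holds become garbage-collectable instead of lingering for the lifetime of the trainer.

    # Illustration only: FakeFetcher / FakeConnector are hypothetical stand-ins.
    import gc
    import weakref


    class FakeFetcher:
        """Plays the role of a data fetcher holding prefetched batches."""

        def __init__(self):
            self.prefetched = [bytes(1024) for _ in range(8)]

        def teardown(self):
            self.prefetched.clear()


    class FakeConnector:
        """Plays the role of the data connector owning the fetcher."""

        def __init__(self):
            self.train_data_fetcher = FakeFetcher()

        def teardown(self):
            if self.train_data_fetcher:
                self.train_data_fetcher.teardown()
                self.train_data_fetcher = None  # mirrors the patch: drop the reference


    connector = FakeConnector()
    fetcher_ref = weakref.ref(connector.train_data_fetcher)
    connector.teardown()
    gc.collect()
    assert fetcher_ref() is None  # nothing keeps the torn-down fetcher alive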