Reset datafetcher references in teardown (#9387)

* Free references to data fetcher in data connector teardown
ananthsub 2021-09-14 23:47:36 -07:00 committed by GitHub
parent 637f59f1d2
commit 20ebb5ccc4
3 changed files with 21 additions and 3 deletions

@@ -360,6 +360,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `replace_sampler` missing the batch size under specific conditions ([#9367](https://github.com/PyTorchLightning/pytorch-lightning/pull/9367))
- Fixed freeing datafetchers during teardown ([#9387](https://github.com/PyTorchLightning/pytorch-lightning/pull/9387))
- Fixed bug where the training step output needed to be `deepcopy`-ed ([#9349](https://github.com/PyTorchLightning/pytorch-lightning/pull/9349))

@@ -248,12 +248,16 @@ class DataConnector:
def teardown(self) -> None:
if self.train_data_fetcher:
self.train_data_fetcher.teardown()
self.train_data_fetcher = None
if self.validate_data_fetcher:
self.validate_data_fetcher.teardown()
self.validate_data_fetcher = None
if self.test_data_fetcher:
self.test_data_fetcher.teardown()
self.test_data_fetcher = None
if self.sanity_check_data_fetcher:
self.sanity_check_data_fetcher.teardown()
self.sanity_check_data_fetcher = None
class _PatchDataLoader:
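
The key change above: after calling teardown() on each fetcher, the connector also drops its reference so the fetcher (and whatever dataloader state it holds) can be garbage collected once training finishes. A minimal, self-contained sketch of that pattern, using illustrative stand-in classes rather than Lightning's actual ones:

import gc
import weakref


class DummyFetcher:
    """Stand-in for a data fetcher that may hold dataloader iterators."""

    def teardown(self) -> None:
        # Release per-run state (iterators, prefetched batches, ...).
        pass


class DummyConnector:
    """Illustrative connector that tears down and then drops its fetcher."""

    def __init__(self) -> None:
        self.train_data_fetcher = DummyFetcher()

    def teardown(self) -> None:
        if self.train_data_fetcher:
            self.train_data_fetcher.teardown()
        # Dropping the reference is what allows the fetcher to be freed.
        self.train_data_fetcher = None


connector = DummyConnector()
fetcher_ref = weakref.ref(connector.train_data_fetcher)
connector.teardown()
gc.collect()
assert fetcher_ref() is None  # the fetcher was collected once released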

@@ -21,7 +21,7 @@ import torch
from torch import tensor
from torch.utils.data import DataLoader, Dataset, IterableDataset
from pytorch_lightning import Trainer
from pytorch_lightning import Callback, Trainer
from pytorch_lightning.trainer.supporters import CombinedLoader
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.fetching import DataFetcher, DataLoaderIterDataFetcher, InterBatchParallelDataFetcher
@@ -179,6 +179,16 @@ def test_trainer_num_prefetch_batches(tmpdir):
model = RecommenderModel()
class AssertFetcher(Callback):
def __init__(self, check_inter_batch: bool):
self._check_inter_batch = check_inter_batch
def on_train_epoch_end(self, trainer, lightning_module):
if self._check_inter_batch:
assert isinstance(trainer.data_connector.train_data_fetcher, InterBatchParallelDataFetcher)
else:
assert isinstance(trainer.data_connector.train_data_fetcher, DataFetcher)
trainer_kwargs = dict(
default_root_dir=tmpdir,
max_epochs=1,
@@ -186,6 +196,7 @@ def test_trainer_num_prefetch_batches(tmpdir):
limit_train_batches=4,
limit_val_batches=0,
num_sanity_val_steps=0,
callbacks=[AssertFetcher(check_inter_batch=True)],
)
with mock.patch.dict(os.environ, {"PL_INTER_BATCH_PARALLELISM": "1"}):
@@ -193,16 +204,16 @@ def test_trainer_num_prefetch_batches(tmpdir):
trainer = Trainer(**trainer_kwargs)
trainer.fit(model)
t1 = time()
assert isinstance(trainer.data_connector.train_data_fetcher, InterBatchParallelDataFetcher)
global_step = trainer.global_step
torch.cuda.synchronize()
trainer_kwargs["callbacks"] = [AssertFetcher(check_inter_batch=False)]
t2 = time()
trainer = Trainer(**trainer_kwargs)
trainer.fit(model)
t3 = time()
assert isinstance(trainer.data_connector.train_data_fetcher, DataFetcher)
assert global_step == trainer.global_step == 4
ratio = (t3 - t2) / (t1 - t0)
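
Because the connector now clears these references during teardown, an isinstance check on trainer.data_connector.train_data_fetcher after trainer.fit() returns would only see None; the test therefore moves the assertion into a Callback hook that fires while training is still running. A short sketch of the same pattern (the callback name is illustrative; the attributes mirror the test above):

from pytorch_lightning import Callback, Trainer
from pytorch_lightning.utilities.fetching import DataFetcher


class CheckFetcherAlive(Callback):
    """Illustrative callback: inspect the data fetcher before teardown frees it."""

    def on_train_epoch_end(self, trainer, lightning_module):
        # The fetcher reference is still intact while the fit loop is active.
        assert isinstance(trainer.data_connector.train_data_fetcher, DataFetcher)


# Usage sketch (model and tmpdir as in the test above):
# trainer = Trainer(default_root_dir=tmpdir, max_epochs=1,
#                   callbacks=[CheckFetcherAlive()])
# trainer.fit(model)
# After fit() returns, trainer.data_connector.train_data_fetcher is None.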