Reset datafetcher references in teardown (#9387)
* Free references to data fetcher in data connector teardown
parent 637f59f1d2 · commit 20ebb5ccc4
@@ -360,6 +360,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `replace_sampler` missing the batch size under specific conditions ([#9367](https://github.com/PyTorchLightning/pytorch-lightning/pull/9367))
+- Fixed freeing datafetchers during teardown ([#9387](https://github.com/PyTorchLightning/pytorch-lightning/pull/9387))
 - Fixed bug where the training step output needed to be `deepcopy`-ed ([#9349](https://github.com/PyTorchLightning/pytorch-lightning/pull/9349))
@@ -248,12 +248,16 @@ class DataConnector:
     def teardown(self) -> None:
         if self.train_data_fetcher:
             self.train_data_fetcher.teardown()
+            self.train_data_fetcher = None
         if self.validate_data_fetcher:
             self.validate_data_fetcher.teardown()
+            self.validate_data_fetcher = None
         if self.test_data_fetcher:
             self.test_data_fetcher.teardown()
+            self.test_data_fetcher = None
         if self.sanity_check_data_fetcher:
             self.sanity_check_data_fetcher.teardown()
+            self.sanity_check_data_fetcher = None
 
 
 class _PatchDataLoader:
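The point of the extra assignments is that `teardown()` now also drops the connector's references, so the fetchers (and whatever dataloader iterators they hold) can be garbage-collected once a run finishes. The pattern can be illustrated outside of Lightning with a minimal sketch; `DummyFetcher` and `DummyConnector` below are hypothetical stand-ins for this example, not part of the library:

```python
import gc
import weakref


class DummyFetcher:
    """Stand-in for a data fetcher that owns resources released in teardown()."""

    def __init__(self) -> None:
        self.torn_down = False

    def teardown(self) -> None:
        self.torn_down = True


class DummyConnector:
    """Mimics the teardown pattern from the diff above."""

    def __init__(self) -> None:
        self.train_data_fetcher = DummyFetcher()

    def teardown(self) -> None:
        if self.train_data_fetcher:
            self.train_data_fetcher.teardown()
            # Dropping the reference is what lets the fetcher be garbage-collected;
            # calling teardown() alone would keep the object alive on the connector.
            self.train_data_fetcher = None


connector = DummyConnector()
fetcher_ref = weakref.ref(connector.train_data_fetcher)
connector.teardown()
gc.collect()
assert connector.train_data_fetcher is None
assert fetcher_ref() is None  # nothing keeps the old fetcher alive anymore
```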
@@ -21,7 +21,7 @@ import torch
 from torch import tensor
 from torch.utils.data import DataLoader, Dataset, IterableDataset
 
-from pytorch_lightning import Trainer
+from pytorch_lightning import Callback, Trainer
 from pytorch_lightning.trainer.supporters import CombinedLoader
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.fetching import DataFetcher, DataLoaderIterDataFetcher, InterBatchParallelDataFetcher
@@ -179,6 +179,16 @@ def test_trainer_num_prefetch_batches(tmpdir):
     model = RecommenderModel()
 
+    class AssertFetcher(Callback):
+        def __init__(self, check_inter_batch: bool):
+            self._check_inter_batch = check_inter_batch
+
+        def on_train_epoch_end(self, trainer, lightning_module):
+            if self._check_inter_batch:
+                assert isinstance(trainer.data_connector.train_data_fetcher, InterBatchParallelDataFetcher)
+            else:
+                assert isinstance(trainer.data_connector.train_data_fetcher, DataFetcher)
+
     trainer_kwargs = dict(
         default_root_dir=tmpdir,
         max_epochs=1,
@@ -186,6 +196,7 @@ def test_trainer_num_prefetch_batches(tmpdir):
         limit_train_batches=4,
         limit_val_batches=0,
         num_sanity_val_steps=0,
+        callbacks=[AssertFetcher(check_inter_batch=True)],
     )
 
     with mock.patch.dict(os.environ, {"PL_INTER_BATCH_PARALLELISM": "1"}):
@@ -193,16 +204,16 @@ def test_trainer_num_prefetch_batches(tmpdir):
         trainer = Trainer(**trainer_kwargs)
         trainer.fit(model)
         t1 = time()
-        assert isinstance(trainer.data_connector.train_data_fetcher, InterBatchParallelDataFetcher)
         global_step = trainer.global_step
 
     torch.cuda.synchronize()
 
+    trainer_kwargs["callbacks"] = [AssertFetcher(check_inter_batch=False)]
+
     t2 = time()
     trainer = Trainer(**trainer_kwargs)
     trainer.fit(model)
     t3 = time()
-    assert isinstance(trainer.data_connector.train_data_fetcher, DataFetcher)
 
     assert global_step == trainer.global_step == 4
     ratio = (t3 - t2) / (t1 - t0)
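Because the connector now clears its fetcher references at teardown, the post-`fit()` `isinstance` checks had to move into a callback that runs while training is still active; once `fit()` returns, the reference is expected to be gone. A minimal regression-style check could look like the sketch below; it is a hypothetical test for illustration (the `TinyModel` and `RandomDataset` helpers are made up for this example), not one from this PR:

```python
import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer


class RandomDataset(Dataset):
    # Tiny synthetic dataset so the example is self-contained.
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return torch.randn(4)


class TinyModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        return DataLoader(RandomDataset(), batch_size=2)


def test_train_data_fetcher_is_freed(tmpdir):
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1)
    trainer.fit(TinyModel())
    # With this change, teardown() tears the fetcher down *and* drops the
    # reference, so nothing should be left on the connector after fit().
    assert trainer.data_connector.train_data_fetcher is None
```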