diff --git a/CHANGELOG.md b/CHANGELOG.md index e8008ec06b..50d5ef6de2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -59,6 +59,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `save_config_filename` init argument to `LightningCLI` to ease resolving name conflicts ([#7741](https://github.com/PyTorchLightning/pytorch-lightning/pull/7741)) +- Added reset dataloader hooks to Training Plugins and Accelerators ([#7861](https://github.com/PyTorchLightning/pytorch-lightning/pull/7861)) + + ### Changed - Changed calling of `untoggle_optimizer(opt_idx)` out of the closure function ([#7563](https://github.com/PyTorchLightning/pytorch-lightning/pull/7563) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 4ea017ae0c..2938feee83 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -460,6 +460,22 @@ class Accelerator: """ return self.training_type_plugin.process_dataloader(dataloader) + def on_reset_train_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]: + """Called before resetting the train dataloader.""" + return self.training_type_plugin.on_reset_train_dataloader(dataloader) + + def on_reset_val_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]: + """Called before resetting the val dataloader.""" + return self.training_type_plugin.on_reset_val_dataloader(dataloader) + + def on_reset_test_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]: + """Called before resetting the test dataloader.""" + return self.training_type_plugin.on_reset_test_dataloader(dataloader) + + def on_reset_predict_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]: + """Called before resetting the predict dataloader.""" + return self.training_type_plugin.on_reset_predict_dataloader(dataloader) + @property def results(self) -> Any: """ diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index c5b3e7eea3..f965a56f23 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -195,6 +195,22 @@ class TrainingTypePlugin(Plugin, ABC): """ return dataloader + def on_reset_train_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]: + """Called before resetting the train dataloader.""" + return dataloader + + def on_reset_val_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]: + """Called before resetting the val dataloader.""" + return dataloader + + def on_reset_test_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]: + """Called before resetting the test dataloader.""" + return dataloader + + def on_reset_predict_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]: + """Called before resetting the predict dataloader.""" + return dataloader + def init_optimizers(self, trainer: 'pl.Trainer', model: 'pl.LightningModule'): return trainer.init_optimizers(model) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 53c9b07dff..a16ac0c7f5 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -261,6 
+261,9 @@ class TrainerDataLoadingMixin(ABC): # wrap the sequence of train loaders to a CombinedLoader object for computing the num_training_batches self.train_dataloader = CombinedLoader(self.train_dataloader, self.data_connector.multiple_trainloader_mode) + # allow accelerator to modify dataloader + self.train_dataloader = self.accelerator.on_reset_train_dataloader(self.train_dataloader) + self.num_training_batches = len(self.train_dataloader) if has_len(self.train_dataloader) else float('inf') if isinstance(self.limit_train_batches, int) or self.limit_train_batches == 0.0: @@ -361,6 +364,10 @@ class TrainerDataLoadingMixin(ABC): # add worker_init_fn for correct seeding in worker processes apply_to_collection(dataloaders, dtype=DataLoader, function=self.auto_add_worker_init_fn) + # allow accelerator to modify dataloader + hook_name = f"on_reset_{mode}_dataloader" + dataloaders = getattr(self.accelerator, hook_name)(dataloaders) + loader_num_batches = [] # determine number of batches diff --git a/tests/accelerators/test_cpu.py b/tests/accelerators/test_cpu.py index c7d7f98ae9..7be1c6b9d1 100644 --- a/tests/accelerators/test_cpu.py +++ b/tests/accelerators/test_cpu.py @@ -7,6 +7,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.accelerators import CPUAccelerator from pytorch_lightning.plugins import SingleDevicePlugin from pytorch_lightning.plugins.precision import MixedPrecisionPlugin +from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers.boring_model import BoringModel @@ -50,3 +51,112 @@ def test_plugin_setup_optimizers_in_pre_dispatch(tmpdir, delay_dispatch): model = TestModel() trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, plugins=CustomPlugin(device=torch.device("cpu"))) trainer.fit(model) + + +def test_accelerator_on_reset_dataloader_hooks(tmpdir): + """ + Ensure data-loader hooks are called using an Accelerator. 
+ """ + + class CustomAccelerator(CPUAccelerator): + train_count: int = 0 + val_count: int = 0 + test_count: int = 0 + predict_count: int = 0 + + def on_reset_train_dataloader(self, dataloader): + self.train_count += 1 + assert self.lightning_module.trainer.training + return super().on_reset_train_dataloader(dataloader) + + def on_reset_val_dataloader(self, dataloader): + self.val_count += 1 + assert self.lightning_module.trainer.training or self.lightning_module.trainer.validating + return super().on_reset_val_dataloader(dataloader) + + def on_reset_test_dataloader(self, dataloader): + self.test_count += 1 + assert self.lightning_module.trainer.testing + return super().on_reset_test_dataloader(dataloader) + + def on_reset_predict_dataloader(self, dataloader): + self.predict_count += 1 + assert self.lightning_module.trainer.predicting + return super().on_reset_predict_dataloader(dataloader) + + model = BoringModel() + accelerator = CustomAccelerator(PrecisionPlugin(), SingleDevicePlugin(device=torch.device('cpu'))) + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, accelerator=accelerator) + trainer.fit(model) + trainer.validate(model) + trainer.test(model) + trainer.predict(model, dataloaders=model.test_dataloader()) + # assert that all loader hooks were called + assert accelerator.train_count == 1 + assert accelerator.val_count == 1 # only called once during the entire session + assert accelerator.test_count == 1 + assert accelerator.predict_count == 1 + + accelerator = CustomAccelerator(PrecisionPlugin(), SingleDevicePlugin(device=torch.device('cpu'))) + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, accelerator=accelerator) + trainer.validate(model) + trainer.test(model) + trainer.predict(model) + # assert val/test/predict loader hooks were called + assert accelerator.val_count == 1 + assert accelerator.test_count == 1 + assert accelerator.predict_count == 1 + + +def test_plugin_on_reset_dataloader_hooks(tmpdir): + """ + Ensure data-loader hooks are called using a Plugin. 
+ """ + + class CustomPlugin(SingleDevicePlugin): + train_count: int = 0 + val_count: int = 0 + test_count: int = 0 + predict_count: int = 0 + + def on_reset_train_dataloader(self, dataloader): + self.train_count += 1 + assert self.lightning_module.trainer.training + return super().on_reset_train_dataloader(dataloader) + + def on_reset_val_dataloader(self, dataloader): + self.val_count += 1 + assert self.lightning_module.trainer.training or self.lightning_module.trainer.validating + return super().on_reset_val_dataloader(dataloader) + + def on_reset_test_dataloader(self, dataloader): + self.test_count += 1 + assert self.lightning_module.trainer.testing + return super().on_reset_test_dataloader(dataloader) + + def on_reset_predict_dataloader(self, dataloader): + self.predict_count += 1 + assert self.lightning_module.trainer.predicting + return super().on_reset_predict_dataloader(dataloader) + + plugin = CustomPlugin(device=torch.device('cpu')) + model = BoringModel() + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, plugins=plugin) + trainer.fit(model) + trainer.validate(model) + trainer.test(model) + trainer.predict(model, dataloaders=model.test_dataloader()) + # assert that all loader hooks were called + assert plugin.train_count == 1 + assert plugin.val_count == 1 # only called once during the entire session + assert plugin.test_count == 1 + assert plugin.predict_count == 1 + plugin = CustomPlugin(device=torch.device('cpu')) + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, plugins=plugin) + trainer.validate(model) + trainer.test(model) + trainer.predict(model) + # assert val/test/predict loader hooks were called + assert plugin.val_count == 1 + assert plugin.test_count == 1 + assert plugin.predict_count == 1