[IPU] Add reset dataloader hooks to training type plugin 3/n (#7861)

* Add hooks

* Add tests for hooks

* Add changelog

* Test changes, add typing
Sean Naren authored on 2021-06-07 11:37:09 +01:00, committed by GitHub
parent d1becce4c1
commit 6388c29e87
5 changed files with 152 additions and 0 deletions


@@ -59,6 +59,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `save_config_filename` init argument to `LightningCLI` to ease resolving name conflicts ([#7741](https://github.com/PyTorchLightning/pytorch-lightning/pull/7741))
- Added reset dataloader hooks to Training Plugins and Accelerators ([#7861](https://github.com/PyTorchLightning/pytorch-lightning/pull/7861))

### Changed

- Moved the call to `untoggle_optimizer(opt_idx)` out of the closure function ([#7563](https://github.com/PyTorchLightning/pytorch-lightning/pull/7563))


@@ -460,6 +460,22 @@ class Accelerator:
        """
        return self.training_type_plugin.process_dataloader(dataloader)

    def on_reset_train_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
        """Called before resetting the train dataloader."""
        return self.training_type_plugin.on_reset_train_dataloader(dataloader)

    def on_reset_val_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
        """Called before resetting the val dataloader."""
        return self.training_type_plugin.on_reset_val_dataloader(dataloader)

    def on_reset_test_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
        """Called before resetting the test dataloader."""
        return self.training_type_plugin.on_reset_test_dataloader(dataloader)

    def on_reset_predict_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
        """Called before resetting the predict dataloader."""
        return self.training_type_plugin.on_reset_predict_dataloader(dataloader)

    @property
    def results(self) -> Any:
        """


@@ -195,6 +195,22 @@ class TrainingTypePlugin(Plugin, ABC):
        """
        return dataloader

    def on_reset_train_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
        """Called before resetting the train dataloader."""
        return dataloader

    def on_reset_val_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
        """Called before resetting the val dataloader."""
        return dataloader

    def on_reset_test_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
        """Called before resetting the test dataloader."""
        return dataloader

    def on_reset_predict_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
        """Called before resetting the predict dataloader."""
        return dataloader

    def init_optimizers(self, trainer: 'pl.Trainer', model: 'pl.LightningModule'):
        return trainer.init_optimizers(model)
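Every plugin hook defaults to the identity, so existing plugins are unaffected; a device-specific plugin (the IPU plugin this series is building toward) only has to override the hooks it cares about and re-wrap the loader. A hypothetical sketch in the same `SingleDevicePlugin` subclassing style as the tests below; `DeviceDataLoader` is an invented stand-in for a device-specific wrapper, not a real Lightning class:

```python
from torch.utils.data import DataLoader

from pytorch_lightning.plugins import SingleDevicePlugin


class DeviceDataLoader(DataLoader):
    """Invented stand-in for a device-specific loader (e.g. an IPU-aware one)."""


class WrappingPlugin(SingleDevicePlugin):
    def on_reset_train_dataloader(self, dataloader):
        # Rebuild the loader with the device wrapper; the trainer then uses the
        # returned object when it computes num_training_batches.
        if isinstance(dataloader, DataLoader):
            return DeviceDataLoader(dataloader.dataset, batch_size=dataloader.batch_size)
        return dataloader
```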


@@ -261,6 +261,9 @@ class TrainerDataLoadingMixin(ABC):
        # wrap the sequence of train loaders to a CombinedLoader object for computing the num_training_batches
        self.train_dataloader = CombinedLoader(self.train_dataloader, self.data_connector.multiple_trainloader_mode)

        # allow accelerator to modify dataloader
        self.train_dataloader = self.accelerator.on_reset_train_dataloader(self.train_dataloader)

        self.num_training_batches = len(self.train_dataloader) if has_len(self.train_dataloader) else float('inf')

        if isinstance(self.limit_train_batches, int) or self.limit_train_batches == 0.0:

@@ -361,6 +364,10 @@ class TrainerDataLoadingMixin(ABC):
        # add worker_init_fn for correct seeding in worker processes
        apply_to_collection(dataloaders, dtype=DataLoader, function=self.auto_add_worker_init_fn)

        # allow accelerator to modify dataloader
        hook_name = f"on_reset_{mode}_dataloader"
        dataloaders = getattr(self.accelerator, hook_name)(dataloaders)

        loader_num_batches = []

        # determine number of batches
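Note that the eval-side dispatch builds the hook name from the `mode` string, so it only resolves for the suffixes the Accelerator actually defines ('val', 'test', 'predict'; the train path above calls its hook directly). A minimal illustration of the same getattr dispatch, using a toy accelerator rather than the real class:

```python
class ToyAccel:
    """Defines one on_reset_*_dataloader hook per evaluation mode."""

    def on_reset_val_dataloader(self, dl):
        return dl

    def on_reset_test_dataloader(self, dl):
        return dl

    def on_reset_predict_dataloader(self, dl):
        return dl


accel = ToyAccel()
for mode in ("val", "test", "predict"):
    hook_name = f"on_reset_{mode}_dataloader"
    # a misspelled mode would raise AttributeError here
    dataloaders = getattr(accel, hook_name)(["batch"])
```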


@@ -7,6 +7,7 @@ from pytorch_lightning import Trainer
from pytorch_lightning.accelerators import CPUAccelerator
from pytorch_lightning.plugins import SingleDevicePlugin
from pytorch_lightning.plugins.precision import MixedPrecisionPlugin
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers.boring_model import BoringModel

@@ -50,3 +51,112 @@ def test_plugin_setup_optimizers_in_pre_dispatch(tmpdir, delay_dispatch):
    model = TestModel()
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, plugins=CustomPlugin(device=torch.device("cpu")))
    trainer.fit(model)

def test_accelerator_on_reset_dataloader_hooks(tmpdir):
    """
    Ensure data-loader hooks are called using an Accelerator.
    """

    class CustomAccelerator(CPUAccelerator):
        train_count: int = 0
        val_count: int = 0
        test_count: int = 0
        predict_count: int = 0

        def on_reset_train_dataloader(self, dataloader):
            self.train_count += 1
            assert self.lightning_module.trainer.training
            return super().on_reset_train_dataloader(dataloader)

        def on_reset_val_dataloader(self, dataloader):
            self.val_count += 1
            assert self.lightning_module.trainer.training or self.lightning_module.trainer.validating
            return super().on_reset_val_dataloader(dataloader)

        def on_reset_test_dataloader(self, dataloader):
            self.test_count += 1
            assert self.lightning_module.trainer.testing
            return super().on_reset_test_dataloader(dataloader)

        def on_reset_predict_dataloader(self, dataloader):
            self.predict_count += 1
            assert self.lightning_module.trainer.predicting
            return super().on_reset_predict_dataloader(dataloader)

    model = BoringModel()
    accelerator = CustomAccelerator(PrecisionPlugin(), SingleDevicePlugin(device=torch.device('cpu')))
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, accelerator=accelerator)
    trainer.fit(model)
    trainer.validate(model)
    trainer.test(model)
    trainer.predict(model, dataloaders=model.test_dataloader())

    # assert that all loader hooks were called
    assert accelerator.train_count == 1
    assert accelerator.val_count == 1  # only called once during the entire session
    assert accelerator.test_count == 1
    assert accelerator.predict_count == 1

    accelerator = CustomAccelerator(PrecisionPlugin(), SingleDevicePlugin(device=torch.device('cpu')))
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, accelerator=accelerator)
    trainer.validate(model)
    trainer.test(model)
    trainer.predict(model)

    # assert val/test/predict loader hooks were called
    assert accelerator.val_count == 1
    assert accelerator.test_count == 1
    assert accelerator.predict_count == 1

def test_plugin_on_reset_dataloader_hooks(tmpdir):
    """
    Ensure data-loader hooks are called using a Plugin.
    """

    class CustomPlugin(SingleDevicePlugin):
        train_count: int = 0
        val_count: int = 0
        test_count: int = 0
        predict_count: int = 0

        def on_reset_train_dataloader(self, dataloader):
            self.train_count += 1
            assert self.lightning_module.trainer.training
            return super().on_reset_train_dataloader(dataloader)

        def on_reset_val_dataloader(self, dataloader):
            self.val_count += 1
            assert self.lightning_module.trainer.training or self.lightning_module.trainer.validating
            return super().on_reset_val_dataloader(dataloader)

        def on_reset_test_dataloader(self, dataloader):
            self.test_count += 1
            assert self.lightning_module.trainer.testing
            return super().on_reset_test_dataloader(dataloader)

        def on_reset_predict_dataloader(self, dataloader):
            self.predict_count += 1
            assert self.lightning_module.trainer.predicting
            return super().on_reset_predict_dataloader(dataloader)

    plugin = CustomPlugin(device=torch.device('cpu'))
    model = BoringModel()
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, plugins=plugin)
    trainer.fit(model)
    trainer.validate(model)
    trainer.test(model)
    trainer.predict(model, dataloaders=model.test_dataloader())

    # assert that all loader hooks were called
    assert plugin.train_count == 1
    assert plugin.val_count == 1  # only called once during the entire session
    assert plugin.test_count == 1
    assert plugin.predict_count == 1

    plugin = CustomPlugin(device=torch.device('cpu'))
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, plugins=plugin)
    trainer.validate(model)
    trainer.test(model)
    trainer.predict(model)

    # assert val/test/predict loader hooks were called
    assert plugin.val_count == 1
    assert plugin.test_count == 1
    assert plugin.predict_count == 1