[IPU] Add reset dataloader hooks to training type plugin 3/n (#7861)
* Add hooks
* Add tests for hooks
* Add changelog
* Test changes, add typing
parent d1becce4c1
commit 6388c29e87
@@ -59,6 +59,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added `save_config_filename` init argument to `LightningCLI` to ease resolving name conflicts ([#7741](https://github.com/PyTorchLightning/pytorch-lightning/pull/7741))

- Added reset dataloader hooks to Training Plugins and Accelerators ([#7861](https://github.com/PyTorchLightning/pytorch-lightning/pull/7861))

### Changed

- Changed calling of `untoggle_optimizer(opt_idx)` out of the closure function ([#7563](https://github.com/PyTorchLightning/pytorch-lightning/pull/7563))

@@ -460,6 +460,22 @@ class Accelerator:
        """
        return self.training_type_plugin.process_dataloader(dataloader)

    def on_reset_train_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
        """Called before resetting the train dataloader."""
        return self.training_type_plugin.on_reset_train_dataloader(dataloader)

    def on_reset_val_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
        """Called before resetting the val dataloader."""
        return self.training_type_plugin.on_reset_val_dataloader(dataloader)

    def on_reset_test_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
        """Called before resetting the test dataloader."""
        return self.training_type_plugin.on_reset_test_dataloader(dataloader)

    def on_reset_predict_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
        """Called before resetting the predict dataloader."""
        return self.training_type_plugin.on_reset_predict_dataloader(dataloader)

    @property
    def results(self) -> Any:
        """

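Each new Accelerator hook is a thin pass-through to the training type plugin, so the defaults leave the loader untouched. A minimal sketch of exercising that pass-through directly, built the same way the tests below construct their accelerator (the toy dataset is illustrative, not part of the PR):

import torch
from torch.utils.data import DataLoader, TensorDataset

from pytorch_lightning.accelerators import CPUAccelerator
from pytorch_lightning.plugins import SingleDevicePlugin
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin

# Build an accelerator with an explicit precision and training type plugin.
accelerator = CPUAccelerator(PrecisionPlugin(), SingleDevicePlugin(device=torch.device("cpu")))
loader = DataLoader(TensorDataset(torch.randn(4, 2)))
# The base plugin hooks are identity functions, so the loader comes back unchanged.
assert accelerator.on_reset_train_dataloader(loader) is loader
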
@@ -195,6 +195,22 @@ class TrainingTypePlugin(Plugin, ABC):
        """
        return dataloader

    def on_reset_train_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
        """Called before resetting the train dataloader."""
        return dataloader

    def on_reset_val_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
        """Called before resetting the val dataloader."""
        return dataloader

    def on_reset_test_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
        """Called before resetting the test dataloader."""
        return dataloader

    def on_reset_predict_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
        """Called before resetting the predict dataloader."""
        return dataloader

    def init_optimizers(self, trainer: 'pl.Trainer', model: 'pl.LightningModule'):
        return trainer.init_optimizers(model)

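Because the base-class hooks are identity functions, a custom plugin only has to override the modes it cares about; this is the extension point the IPU series (this is PR 3/n) is building toward. A minimal sketch, assuming a plugin that merely inspects the incoming loader (the class name and print are illustrative):

import torch

from pytorch_lightning.plugins import SingleDevicePlugin


class InspectingPlugin(SingleDevicePlugin):

    def on_reset_train_dataloader(self, dataloader):
        # `dataloader` may be a plain DataLoader or Lightning's CombinedLoader here;
        # return it (possibly wrapped) to swap in device-specific loading behaviour.
        print(f"resetting train loader: {type(dataloader).__name__}")
        return super().on_reset_train_dataloader(dataloader)


plugin = InspectingPlugin(device=torch.device("cpu"))
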
@@ -261,6 +261,9 @@ class TrainerDataLoadingMixin(ABC):
        # wrap the sequence of train loaders to a CombinedLoader object for computing the num_training_batches
        self.train_dataloader = CombinedLoader(self.train_dataloader, self.data_connector.multiple_trainloader_mode)

        # allow accelerator to modify dataloader
        self.train_dataloader = self.accelerator.on_reset_train_dataloader(self.train_dataloader)

        self.num_training_batches = len(self.train_dataloader) if has_len(self.train_dataloader) else float('inf')

        if isinstance(self.limit_train_batches, int) or self.limit_train_batches == 0.0:

@@ -361,6 +364,10 @@ class TrainerDataLoadingMixin(ABC):
        # add worker_init_fn for correct seeding in worker processes
        apply_to_collection(dataloaders, dtype=DataLoader, function=self.auto_add_worker_init_fn)

        # allow accelerator to modify dataloader
        hook_name = f"on_reset_{mode}_dataloader"
        dataloaders = getattr(self.accelerator, hook_name)(dataloaders)

        loader_num_batches = []

        # determine number of batches

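Resolving the hook by name keeps a single eval-loader reset path serving val, test, and predict alike. A standalone sketch of the `getattr` dispatch used above (the stub class is illustrative):

class _StubAccelerator:

    def on_reset_val_dataloader(self, dataloaders):
        return dataloaders

    def on_reset_test_dataloader(self, dataloaders):
        return dataloaders

    def on_reset_predict_dataloader(self, dataloaders):
        return dataloaders


accelerator = _StubAccelerator()
for mode in ("val", "test", "predict"):
    hook_name = f"on_reset_{mode}_dataloader"
    dataloaders = getattr(accelerator, hook_name)([])  # dispatches to the matching hook
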
@@ -7,6 +7,7 @@ from pytorch_lightning import Trainer
from pytorch_lightning.accelerators import CPUAccelerator
from pytorch_lightning.plugins import SingleDevicePlugin
from pytorch_lightning.plugins.precision import MixedPrecisionPlugin
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers.boring_model import BoringModel

@@ -50,3 +51,112 @@ def test_plugin_setup_optimizers_in_pre_dispatch(tmpdir, delay_dispatch):
    model = TestModel()
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, plugins=CustomPlugin(device=torch.device("cpu")))
    trainer.fit(model)


def test_accelerator_on_reset_dataloader_hooks(tmpdir):
    """
    Ensure data-loader hooks are called using an Accelerator.
    """

    class CustomAccelerator(CPUAccelerator):
        train_count: int = 0
        val_count: int = 0
        test_count: int = 0
        predict_count: int = 0

        def on_reset_train_dataloader(self, dataloader):
            self.train_count += 1
            assert self.lightning_module.trainer.training
            return super().on_reset_train_dataloader(dataloader)

        def on_reset_val_dataloader(self, dataloader):
            self.val_count += 1
            assert self.lightning_module.trainer.training or self.lightning_module.trainer.validating
            return super().on_reset_val_dataloader(dataloader)

        def on_reset_test_dataloader(self, dataloader):
            self.test_count += 1
            assert self.lightning_module.trainer.testing
            return super().on_reset_test_dataloader(dataloader)

        def on_reset_predict_dataloader(self, dataloader):
            self.predict_count += 1
            assert self.lightning_module.trainer.predicting
            return super().on_reset_predict_dataloader(dataloader)

    model = BoringModel()
    accelerator = CustomAccelerator(PrecisionPlugin(), SingleDevicePlugin(device=torch.device('cpu')))
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, accelerator=accelerator)
    trainer.fit(model)
    trainer.validate(model)
    trainer.test(model)
    trainer.predict(model, dataloaders=model.test_dataloader())
    # assert that all loader hooks were called
    assert accelerator.train_count == 1
    assert accelerator.val_count == 1  # only called once during the entire session
    assert accelerator.test_count == 1
    assert accelerator.predict_count == 1

    accelerator = CustomAccelerator(PrecisionPlugin(), SingleDevicePlugin(device=torch.device('cpu')))
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, accelerator=accelerator)
    trainer.validate(model)
    trainer.test(model)
    trainer.predict(model)
    # assert val/test/predict loader hooks were called
    assert accelerator.val_count == 1
    assert accelerator.test_count == 1
    assert accelerator.predict_count == 1


def test_plugin_on_reset_dataloader_hooks(tmpdir):
    """
    Ensure data-loader hooks are called using a Plugin.
    """

    class CustomPlugin(SingleDevicePlugin):
        train_count: int = 0
        val_count: int = 0
        test_count: int = 0
        predict_count: int = 0

        def on_reset_train_dataloader(self, dataloader):
            self.train_count += 1
            assert self.lightning_module.trainer.training
            return super().on_reset_train_dataloader(dataloader)

        def on_reset_val_dataloader(self, dataloader):
            self.val_count += 1
            assert self.lightning_module.trainer.training or self.lightning_module.trainer.validating
            return super().on_reset_val_dataloader(dataloader)

        def on_reset_test_dataloader(self, dataloader):
            self.test_count += 1
            assert self.lightning_module.trainer.testing
            return super().on_reset_test_dataloader(dataloader)

        def on_reset_predict_dataloader(self, dataloader):
            self.predict_count += 1
            assert self.lightning_module.trainer.predicting
            return super().on_reset_predict_dataloader(dataloader)

    plugin = CustomPlugin(device=torch.device('cpu'))
    model = BoringModel()
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, plugins=plugin)
    trainer.fit(model)
    trainer.validate(model)
    trainer.test(model)
    trainer.predict(model, dataloaders=model.test_dataloader())
    # assert that all loader hooks were called
    assert plugin.train_count == 1
    assert plugin.val_count == 1  # only called once during the entire session
    assert plugin.test_count == 1
    assert plugin.predict_count == 1

    plugin = CustomPlugin(device=torch.device('cpu'))
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, plugins=plugin)
    trainer.validate(model)
    trainer.test(model)
    trainer.predict(model)
    # assert val/test/predict loader hooks were called
    assert plugin.val_count == 1
    assert plugin.test_count == 1
    assert plugin.predict_count == 1