Add `LightningCLI(save_config_overwrite=False|True)` (#8059)

This commit is contained in:
Carlos Mocholí 2021-06-21 17:58:02 +02:00 committed by GitHub
parent d1efae2e47
commit d9bf9759fb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 30 additions and 6 deletions

View File

@ -71,6 +71,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `save_config_filename` init argument to `LightningCLI` to ease resolving name conflicts ([#7741](https://github.com/PyTorchLightning/pytorch-lightning/pull/7741))
- Added `save_config_overwrite` init argument to `LightningCLI` to ease overwriting existing config files ([#8059](https://github.com/PyTorchLightning/pytorch-lightning/pull/8059))
- Added reset dataloader hooks to Training Plugins and Accelerators ([#7861](https://github.com/PyTorchLightning/pytorch-lightning/pull/7861))

View File

@ -89,20 +89,24 @@ class SaveConfigCallback(Callback):
parser: LightningArgumentParser,
config: Union[Namespace, Dict[str, Any]],
config_filename: str,
overwrite: bool = False,
) -> None:
self.parser = parser
self.config = config
self.config_filename = config_filename
self.overwrite = overwrite
def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
log_dir = trainer.log_dir or trainer.default_root_dir
config_path = os.path.join(log_dir, self.config_filename)
if os.path.isfile(config_path):
if not self.overwrite and os.path.isfile(config_path):
raise RuntimeError(
f'{self.__class__.__name__} expected {config_path} to not exist. '
'Aborting to avoid overwriting results of a previous run.'
f'{self.__class__.__name__} expected {config_path} to NOT exist. Aborting to avoid overwriting'
' results of a previous run. You can delete the previous config file,'
' set `LightningCLI(save_config_callback=None)` to disable config saving,'
' or set `LightningCLI(save_config_overwrite=True)` to overwrite the config file.'
)
self.parser.save(self.config, config_path, skip_none=False)
self.parser.save(self.config, config_path, skip_none=False, overwrite=self.overwrite)
class LightningCLI:
@ -112,8 +116,9 @@ class LightningCLI:
self,
model_class: Type[LightningModule],
datamodule_class: Type[LightningDataModule] = None,
save_config_callback: Type[SaveConfigCallback] = SaveConfigCallback,
save_config_callback: Optional[Type[SaveConfigCallback]] = SaveConfigCallback,
save_config_filename: str = 'config.yaml',
save_config_overwrite: bool = False,
trainer_class: Type[Trainer] = Trainer,
trainer_defaults: Dict[str, Any] = None,
seed_everything_default: int = None,
@ -150,6 +155,8 @@ class LightningCLI:
model_class: :class:`~pytorch_lightning.core.lightning.LightningModule` class to train on.
datamodule_class: An optional :class:`~pytorch_lightning.core.datamodule.LightningDataModule` class.
save_config_callback: A callback class to save the training config.
save_config_filename: Filename for the config file.
save_config_overwrite: Whether to overwrite an existing config file.
trainer_class: An optional subclass of the :class:`~pytorch_lightning.trainer.trainer.Trainer` class.
trainer_defaults: Set to override Trainer defaults or add persistent callbacks.
seed_everything_default: Default value for the :func:`~pytorch_lightning.utilities.seed.seed_everything`
@ -173,6 +180,7 @@ class LightningCLI:
self.datamodule_class = datamodule_class
self.save_config_callback = save_config_callback
self.save_config_filename = save_config_filename
self.save_config_overwrite = save_config_overwrite
self.trainer_class = trainer_class
self.trainer_defaults = {} if trainer_defaults is None else trainer_defaults
self.seed_everything_default = seed_everything_default
@ -246,7 +254,9 @@ class LightningCLI:
else:
self.config_init['trainer']['callbacks'].append(self.trainer_defaults['callbacks'])
if self.save_config_callback and not self.config_init['trainer']['fast_dev_run']:
config_callback = self.save_config_callback(self.parser, self.config, self.save_config_filename)
config_callback = self.save_config_callback(
self.parser, self.config, self.save_config_filename, overwrite=self.save_config_overwrite
)
self.config_init['trainer']['callbacks'].append(config_callback)
self.trainer = self.trainer_class(**self.config_init['trainer'])

View File

@ -603,3 +603,14 @@ def test_lightning_cli_link_arguments(tmpdir):
assert cli.model.batch_size == 8
assert cli.model.num_classes == 5
def test_cli_config_overwrite(tmpdir):
trainer_defaults = {'default_root_dir': str(tmpdir), 'logger': False, 'max_steps': 1, 'max_epochs': 1}
with mock.patch('sys.argv', ['any.py']):
LightningCLI(BoringModel, trainer_defaults=trainer_defaults)
with mock.patch('sys.argv', ['any.py']), pytest.raises(RuntimeError, match='Aborting to avoid overwriting'):
LightningCLI(BoringModel, trainer_defaults=trainer_defaults)
with mock.patch('sys.argv', ['any.py']):
LightningCLI(BoringModel, save_config_overwrite=True, trainer_defaults=trainer_defaults)