Add `LightningCLI(save_config_overwrite=False|True)` (#8059)
This commit is contained in:
parent
d1efae2e47
commit
d9bf9759fb
|
@ -71,6 +71,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Added `save_config_filename` init argument to `LightningCLI` to ease resolving name conflicts ([#7741](https://github.com/PyTorchLightning/pytorch-lightning/pull/7741))
|
||||
|
||||
|
||||
- Added `save_config_overwrite` init argument to `LightningCLI` to ease overwriting existing config files ([#8059](https://github.com/PyTorchLightning/pytorch-lightning/pull/8059))
|
||||
|
||||
|
||||
- Added reset dataloader hooks to Training Plugins and Accelerators ([#7861](https://github.com/PyTorchLightning/pytorch-lightning/pull/7861))
|
||||
|
||||
|
||||
|
|
|
@ -89,20 +89,24 @@ class SaveConfigCallback(Callback):
|
|||
parser: LightningArgumentParser,
|
||||
config: Union[Namespace, Dict[str, Any]],
|
||||
config_filename: str,
|
||||
overwrite: bool = False,
|
||||
) -> None:
|
||||
self.parser = parser
|
||||
self.config = config
|
||||
self.config_filename = config_filename
|
||||
self.overwrite = overwrite
|
||||
|
||||
def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
|
||||
log_dir = trainer.log_dir or trainer.default_root_dir
|
||||
config_path = os.path.join(log_dir, self.config_filename)
|
||||
if os.path.isfile(config_path):
|
||||
if not self.overwrite and os.path.isfile(config_path):
|
||||
raise RuntimeError(
|
||||
f'{self.__class__.__name__} expected {config_path} to not exist. '
|
||||
'Aborting to avoid overwriting results of a previous run.'
|
||||
f'{self.__class__.__name__} expected {config_path} to NOT exist. Aborting to avoid overwriting'
|
||||
' results of a previous run. You can delete the previous config file,'
|
||||
' set `LightningCLI(save_config_callback=None)` to disable config saving,'
|
||||
' or set `LightningCLI(save_config_overwrite=True)` to overwrite the config file.'
|
||||
)
|
||||
self.parser.save(self.config, config_path, skip_none=False)
|
||||
self.parser.save(self.config, config_path, skip_none=False, overwrite=self.overwrite)
|
||||
|
||||
|
||||
class LightningCLI:
|
||||
|
@ -112,8 +116,9 @@ class LightningCLI:
|
|||
self,
|
||||
model_class: Type[LightningModule],
|
||||
datamodule_class: Type[LightningDataModule] = None,
|
||||
save_config_callback: Type[SaveConfigCallback] = SaveConfigCallback,
|
||||
save_config_callback: Optional[Type[SaveConfigCallback]] = SaveConfigCallback,
|
||||
save_config_filename: str = 'config.yaml',
|
||||
save_config_overwrite: bool = False,
|
||||
trainer_class: Type[Trainer] = Trainer,
|
||||
trainer_defaults: Dict[str, Any] = None,
|
||||
seed_everything_default: int = None,
|
||||
|
@ -150,6 +155,8 @@ class LightningCLI:
|
|||
model_class: :class:`~pytorch_lightning.core.lightning.LightningModule` class to train on.
|
||||
datamodule_class: An optional :class:`~pytorch_lightning.core.datamodule.LightningDataModule` class.
|
||||
save_config_callback: A callback class to save the training config.
|
||||
save_config_filename: Filename for the config file.
|
||||
save_config_overwrite: Whether to overwrite an existing config file.
|
||||
trainer_class: An optional subclass of the :class:`~pytorch_lightning.trainer.trainer.Trainer` class.
|
||||
trainer_defaults: Set to override Trainer defaults or add persistent callbacks.
|
||||
seed_everything_default: Default value for the :func:`~pytorch_lightning.utilities.seed.seed_everything`
|
||||
|
@ -173,6 +180,7 @@ class LightningCLI:
|
|||
self.datamodule_class = datamodule_class
|
||||
self.save_config_callback = save_config_callback
|
||||
self.save_config_filename = save_config_filename
|
||||
self.save_config_overwrite = save_config_overwrite
|
||||
self.trainer_class = trainer_class
|
||||
self.trainer_defaults = {} if trainer_defaults is None else trainer_defaults
|
||||
self.seed_everything_default = seed_everything_default
|
||||
|
@ -246,7 +254,9 @@ class LightningCLI:
|
|||
else:
|
||||
self.config_init['trainer']['callbacks'].append(self.trainer_defaults['callbacks'])
|
||||
if self.save_config_callback and not self.config_init['trainer']['fast_dev_run']:
|
||||
config_callback = self.save_config_callback(self.parser, self.config, self.save_config_filename)
|
||||
config_callback = self.save_config_callback(
|
||||
self.parser, self.config, self.save_config_filename, overwrite=self.save_config_overwrite
|
||||
)
|
||||
self.config_init['trainer']['callbacks'].append(config_callback)
|
||||
self.trainer = self.trainer_class(**self.config_init['trainer'])
|
||||
|
||||
|
|
|
@ -603,3 +603,14 @@ def test_lightning_cli_link_arguments(tmpdir):
|
|||
|
||||
assert cli.model.batch_size == 8
|
||||
assert cli.model.num_classes == 5
|
||||
|
||||
|
||||
def test_cli_config_overwrite(tmpdir):
|
||||
trainer_defaults = {'default_root_dir': str(tmpdir), 'logger': False, 'max_steps': 1, 'max_epochs': 1}
|
||||
|
||||
with mock.patch('sys.argv', ['any.py']):
|
||||
LightningCLI(BoringModel, trainer_defaults=trainer_defaults)
|
||||
with mock.patch('sys.argv', ['any.py']), pytest.raises(RuntimeError, match='Aborting to avoid overwriting'):
|
||||
LightningCLI(BoringModel, trainer_defaults=trainer_defaults)
|
||||
with mock.patch('sys.argv', ['any.py']):
|
||||
LightningCLI(BoringModel, save_config_overwrite=True, trainer_defaults=trainer_defaults)
|
||||
|
|
Loading…
Reference in New Issue