diff --git a/CHANGELOG.md b/CHANGELOG.md index 3f1d85b777..763abd754a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -71,6 +71,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `save_config_filename` init argument to `LightningCLI` to ease resolving name conflicts ([#7741](https://github.com/PyTorchLightning/pytorch-lightning/pull/7741)) +- Added `save_config_overwrite` init argument to `LightningCLI` to ease overwriting existing config files ([#8059](https://github.com/PyTorchLightning/pytorch-lightning/pull/8059)) + + - Added reset dataloader hooks to Training Plugins and Accelerators ([#7861](https://github.com/PyTorchLightning/pytorch-lightning/pull/7861)) diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index 1f1788393b..387bbdeb85 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -89,20 +89,24 @@ class SaveConfigCallback(Callback): parser: LightningArgumentParser, config: Union[Namespace, Dict[str, Any]], config_filename: str, + overwrite: bool = False, ) -> None: self.parser = parser self.config = config self.config_filename = config_filename + self.overwrite = overwrite def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: log_dir = trainer.log_dir or trainer.default_root_dir config_path = os.path.join(log_dir, self.config_filename) - if os.path.isfile(config_path): + if not self.overwrite and os.path.isfile(config_path): raise RuntimeError( - f'{self.__class__.__name__} expected {config_path} to not exist. ' - 'Aborting to avoid overwriting results of a previous run.' + f'{self.__class__.__name__} expected {config_path} to NOT exist. Aborting to avoid overwriting' + ' results of a previous run. You can delete the previous config file,' + ' set `LightningCLI(save_config_callback=None)` to disable config saving,' + ' or set `LightningCLI(save_config_overwrite=True)` to overwrite the config file.' ) - self.parser.save(self.config, config_path, skip_none=False) + self.parser.save(self.config, config_path, skip_none=False, overwrite=self.overwrite) class LightningCLI: @@ -112,8 +116,9 @@ class LightningCLI: self, model_class: Type[LightningModule], datamodule_class: Type[LightningDataModule] = None, - save_config_callback: Type[SaveConfigCallback] = SaveConfigCallback, + save_config_callback: Optional[Type[SaveConfigCallback]] = SaveConfigCallback, save_config_filename: str = 'config.yaml', + save_config_overwrite: bool = False, trainer_class: Type[Trainer] = Trainer, trainer_defaults: Dict[str, Any] = None, seed_everything_default: int = None, @@ -150,6 +155,8 @@ class LightningCLI: model_class: :class:`~pytorch_lightning.core.lightning.LightningModule` class to train on. datamodule_class: An optional :class:`~pytorch_lightning.core.datamodule.LightningDataModule` class. save_config_callback: A callback class to save the training config. + save_config_filename: Filename for the config file. + save_config_overwrite: Whether to overwrite an existing config file. trainer_class: An optional subclass of the :class:`~pytorch_lightning.trainer.trainer.Trainer` class. trainer_defaults: Set to override Trainer defaults or add persistent callbacks. seed_everything_default: Default value for the :func:`~pytorch_lightning.utilities.seed.seed_everything` @@ -173,6 +180,7 @@ class LightningCLI: self.datamodule_class = datamodule_class self.save_config_callback = save_config_callback self.save_config_filename = save_config_filename + self.save_config_overwrite = save_config_overwrite self.trainer_class = trainer_class self.trainer_defaults = {} if trainer_defaults is None else trainer_defaults self.seed_everything_default = seed_everything_default @@ -246,7 +254,9 @@ class LightningCLI: else: self.config_init['trainer']['callbacks'].append(self.trainer_defaults['callbacks']) if self.save_config_callback and not self.config_init['trainer']['fast_dev_run']: - config_callback = self.save_config_callback(self.parser, self.config, self.save_config_filename) + config_callback = self.save_config_callback( + self.parser, self.config, self.save_config_filename, overwrite=self.save_config_overwrite + ) self.config_init['trainer']['callbacks'].append(config_callback) self.trainer = self.trainer_class(**self.config_init['trainer']) diff --git a/tests/utilities/test_cli.py b/tests/utilities/test_cli.py index 458726662d..4e37f511f0 100644 --- a/tests/utilities/test_cli.py +++ b/tests/utilities/test_cli.py @@ -603,3 +603,14 @@ def test_lightning_cli_link_arguments(tmpdir): assert cli.model.batch_size == 8 assert cli.model.num_classes == 5 + + +def test_cli_config_overwrite(tmpdir): + trainer_defaults = {'default_root_dir': str(tmpdir), 'logger': False, 'max_steps': 1, 'max_epochs': 1} + + with mock.patch('sys.argv', ['any.py']): + LightningCLI(BoringModel, trainer_defaults=trainer_defaults) + with mock.patch('sys.argv', ['any.py']), pytest.raises(RuntimeError, match='Aborting to avoid overwriting'): + LightningCLI(BoringModel, trainer_defaults=trainer_defaults) + with mock.patch('sys.argv', ['any.py']): + LightningCLI(BoringModel, save_config_overwrite=True, trainer_defaults=trainer_defaults)