diff --git a/CHANGELOG.md b/CHANGELOG.md
index 22113b9d6d..ddaf4288a0 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -175,6 +175,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added `on_load_checkpoint` and `on_save_checkpoint` hooks to the `PrecisionPlugin` base class ([#7831](https://github.com/PyTorchLightning/pytorch-lightning/pull/7831))
 
+- `LightningCLI` now aborts with a clearer message if config already exists and disables save config during `fast_dev_run` ([#7963](https://github.com/PyTorchLightning/pytorch-lightning/pull/7963))
+
+
 
 ### Deprecated
 
diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py
index 3fae52cde1..aed9de3355 100644
--- a/pytorch_lightning/utilities/cli.py
+++ b/pytorch_lightning/utilities/cli.py
@@ -75,7 +75,11 @@ class LightningArgumentParser(ArgumentParser):
 
 
 class SaveConfigCallback(Callback):
-    """Saves a LightningCLI config to the log_dir when training starts"""
+    """Saves a LightningCLI config to the log_dir when training starts
+
+    Raises:
+        RuntimeError: If the config file already exists in the directory to avoid overwriting a previous run
+    """
 
     def __init__(
         self,
@@ -90,6 +94,11 @@ class SaveConfigCallback(Callback):
 
     def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
         log_dir = trainer.log_dir or trainer.default_root_dir
         config_path = os.path.join(log_dir, self.config_filename)
+        if os.path.isfile(config_path):
+            raise RuntimeError(
+                f'{self.__class__.__name__} expected {config_path} to not exist. '
+                'Aborting to avoid overwriting results of a previous run.'
+            )
         self.parser.save(self.config, config_path, skip_none=False)
@@ -231,7 +240,7 @@ class LightningCLI:
                 self.config_init['trainer']['callbacks'].extend(self.trainer_defaults['callbacks'])
             else:
                 self.config_init['trainer']['callbacks'].append(self.trainer_defaults['callbacks'])
-        if self.save_config_callback is not None:
+        if self.save_config_callback and not self.config_init['trainer']['fast_dev_run']:
             config_callback = self.save_config_callback(self.parser, self.config, self.save_config_filename)
             self.config_init['trainer']['callbacks'].append(config_callback)
         self.trainer = self.trainer_class(**self.config_init['trainer'])
diff --git a/tests/utilities/test_cli.py b/tests/utilities/test_cli.py
index e1623b15c2..40ddbf9b16 100644
--- a/tests/utilities/test_cli.py
+++ b/tests/utilities/test_cli.py
@@ -328,6 +328,31 @@ def test_lightning_cli_args(tmpdir):
     assert config['trainer'] == cli.config['trainer']
 
 
+def test_lightning_cli_save_config_cases(tmpdir):
+
+    config_path = tmpdir / 'config.yaml'
+    cli_args = [
+        f'--trainer.default_root_dir={tmpdir}',
+        '--trainer.logger=False',
+        '--trainer.fast_dev_run=1',
+    ]
+
+    # With fast_dev_run!=False config should not be saved
+    with mock.patch('sys.argv', ['any.py'] + cli_args):
+        LightningCLI(BoringModel)
+    assert not os.path.isfile(config_path)
+
+    # With fast_dev_run==False config should be saved
+    cli_args[-1] = '--trainer.max_epochs=1'
+    with mock.patch('sys.argv', ['any.py'] + cli_args):
+        LightningCLI(BoringModel)
+    assert os.path.isfile(config_path)
+
+    # If run again on same directory exception should be raised since config file already exists
+    with mock.patch('sys.argv', ['any.py'] + cli_args), pytest.raises(RuntimeError):
+        LightningCLI(BoringModel)
+
+
 def test_lightning_cli_config_and_subclass_mode(tmpdir):
     config = dict(