From b2e9fa814fde0f15711c053e3c091a076496fd8c Mon Sep 17 00:00:00 2001 From: Mauricio Villegas Date: Tue, 15 Jun 2021 23:26:39 +0200 Subject: [PATCH] Improvements related to save of config file by LightningCLI (#7963) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * - Exclude SaveConfigCallback for fast_dev_run=True. - SaveConfigCallback give a clearer message if config file already exists. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * - Added unit test - Added entry in changelog - Improved save config docstring * Fix log line * Fixes * Fix changelog entry * Update pytorch_lightning/utilities/cli.py Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com> * Apply suggestions from code review Co-authored-by: Carlos MocholĂ­ * Suggested fixed change Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com> Co-authored-by: Carlos Mocholi --- CHANGELOG.md | 3 +++ pytorch_lightning/utilities/cli.py | 13 +++++++++++-- tests/utilities/test_cli.py | 25 +++++++++++++++++++++++++ 3 files changed, 39 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 22113b9d6d..ddaf4288a0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -175,6 +175,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `on_load_checkpoint` and `on_save_checkpoint` hooks to the `PrecisionPlugin` base class ([#7831](https://github.com/PyTorchLightning/pytorch-lightning/pull/7831)) +- `LightningCLI` now aborts with a clearer message if config already exists and disables save config during `fast_dev_run`([#7963](https://github.com/PyTorchLightning/pytorch-lightning/pull/7963)) + + ### Deprecated diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index 3fae52cde1..aed9de3355 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -75,7 +75,11 @@ class LightningArgumentParser(ArgumentParser): class SaveConfigCallback(Callback): - """Saves a LightningCLI config to the log_dir when training starts""" + """Saves a LightningCLI config to the log_dir when training starts + + Raises: + RuntimeError: If the config file already exists in the directory to avoid overwriting a previous run + """ def __init__( self, @@ -90,6 +94,11 @@ class SaveConfigCallback(Callback): def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: log_dir = trainer.log_dir or trainer.default_root_dir config_path = os.path.join(log_dir, self.config_filename) + if os.path.isfile(config_path): + raise RuntimeError( + f'{self.__class__.__name__} expected {config_path} to not exist. ' + 'Aborting to avoid overwriting results of a previous run.' + ) self.parser.save(self.config, config_path, skip_none=False) @@ -231,7 +240,7 @@ class LightningCLI: self.config_init['trainer']['callbacks'].extend(self.trainer_defaults['callbacks']) else: self.config_init['trainer']['callbacks'].append(self.trainer_defaults['callbacks']) - if self.save_config_callback is not None: + if self.save_config_callback and not self.config_init['trainer']['fast_dev_run']: config_callback = self.save_config_callback(self.parser, self.config, self.save_config_filename) self.config_init['trainer']['callbacks'].append(config_callback) self.trainer = self.trainer_class(**self.config_init['trainer']) diff --git a/tests/utilities/test_cli.py b/tests/utilities/test_cli.py index e1623b15c2..40ddbf9b16 100644 --- a/tests/utilities/test_cli.py +++ b/tests/utilities/test_cli.py @@ -328,6 +328,31 @@ def test_lightning_cli_args(tmpdir): assert config['trainer'] == cli.config['trainer'] +def test_lightning_cli_save_config_cases(tmpdir): + + config_path = tmpdir / 'config.yaml' + cli_args = [ + f'--trainer.default_root_dir={tmpdir}', + '--trainer.logger=False', + '--trainer.fast_dev_run=1', + ] + + # With fast_dev_run!=False config should not be saved + with mock.patch('sys.argv', ['any.py'] + cli_args): + LightningCLI(BoringModel) + assert not os.path.isfile(config_path) + + # With fast_dev_run==False config should be saved + cli_args[-1] = '--trainer.max_epochs=1' + with mock.patch('sys.argv', ['any.py'] + cli_args): + LightningCLI(BoringModel) + assert os.path.isfile(config_path) + + # If run again on same directory exception should be raised since config file already exists + with mock.patch('sys.argv', ['any.py'] + cli_args), pytest.raises(RuntimeError): + LightningCLI(BoringModel) + + def test_lightning_cli_config_and_subclass_mode(tmpdir): config = dict(