Improvements related to saving of the config file by LightningCLI (#7963)

* - Exclude SaveConfigCallback when fast_dev_run=True.
- SaveConfigCallback gives a clearer message if the config file already exists.
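
A minimal usage sketch of the resulting behavior (hedged: `MyModel` and the script name are placeholders, not part of this change; the flags are standard Trainer arguments passed through LightningCLI):

    # cli_sketch.py - illustrative only
    from pytorch_lightning.utilities.cli import LightningCLI
    from my_project.models import MyModel  # placeholder LightningModule, not part of this PR

    # `python cli_sketch.py --trainer.fast_dev_run=1`:
    #   SaveConfigCallback is not added, so no config.yaml is written to the log_dir.
    # `python cli_sketch.py --trainer.max_epochs=1` run twice against the same log_dir:
    #   the second run raises RuntimeError instead of silently overwriting config.yaml.
    cli = LightningCLI(MyModel)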

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* - Added unit test
- Added entry in changelog
- Improved save config docstring

* Fix log line

* Fixes

* Fix changelog entry

* Update pytorch_lightning/utilities/cli.py

Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

* Apply suggested fix

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
Co-authored-by: Carlos Mocholi <carlossmocholi@gmail.com>
Mauricio Villegas 2021-06-15 23:26:39 +02:00 committed by GitHub
parent 971908a1aa
commit b2e9fa814f
3 changed files with 39 additions and 2 deletions

CHANGELOG.md

@@ -175,6 +175,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `on_load_checkpoint` and `on_save_checkpoint` hooks to the `PrecisionPlugin` base class ([#7831](https://github.com/PyTorchLightning/pytorch-lightning/pull/7831))
+- `LightningCLI` now aborts with a clearer message if config already exists and disables save config during `fast_dev_run` ([#7963](https://github.com/PyTorchLightning/pytorch-lightning/pull/7963))
 
 ### Deprecated

pytorch_lightning/utilities/cli.py

@@ -75,7 +75,11 @@ class LightningArgumentParser(ArgumentParser):
 class SaveConfigCallback(Callback):
-    """Saves a LightningCLI config to the log_dir when training starts"""
+    """Saves a LightningCLI config to the log_dir when training starts
+
+    Raises:
+        RuntimeError: If the config file already exists in the directory to avoid overwriting a previous run
+    """
 
     def __init__(
         self,
@@ -90,6 +94,11 @@ class SaveConfigCallback(Callback):
     def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
         log_dir = trainer.log_dir or trainer.default_root_dir
         config_path = os.path.join(log_dir, self.config_filename)
+        if os.path.isfile(config_path):
+            raise RuntimeError(
+                f'{self.__class__.__name__} expected {config_path} to not exist. '
+                'Aborting to avoid overwriting results of a previous run.'
+            )
         self.parser.save(self.config, config_path, skip_none=False)
@@ -231,7 +240,7 @@ class LightningCLI:
                 self.config_init['trainer']['callbacks'].extend(self.trainer_defaults['callbacks'])
             else:
                 self.config_init['trainer']['callbacks'].append(self.trainer_defaults['callbacks'])
-        if self.save_config_callback is not None:
+        if self.save_config_callback and not self.config_init['trainer']['fast_dev_run']:
             config_callback = self.save_config_callback(self.parser, self.config, self.save_config_filename)
             self.config_init['trainer']['callbacks'].append(config_callback)
         self.trainer = self.trainer_class(**self.config_init['trainer'])

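For reference, a hedged opt-out sketch: `self.save_config_callback` and `self.save_config_filename` in the diff above are assumed to come from the LightningCLI constructor arguments of the same names, so a run that deliberately reuses a log_dir can disable or rename the saved config instead of hitting the new RuntimeError (script and model names below are placeholders):

    # run.py - illustrative workaround sketch, not part of this diff
    from pytorch_lightning.utilities.cli import LightningCLI
    from my_project.models import MyModel  # placeholder

    # Passing save_config_callback=None skips config saving entirely;
    # alternatively, save_config_filename can point at a per-run filename.
    cli = LightningCLI(MyModel, save_config_callback=None)
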
tests/utilities/test_cli.py

@@ -328,6 +328,31 @@ def test_lightning_cli_args(tmpdir):
     assert config['trainer'] == cli.config['trainer']
 
 
+def test_lightning_cli_save_config_cases(tmpdir):
+
+    config_path = tmpdir / 'config.yaml'
+    cli_args = [
+        f'--trainer.default_root_dir={tmpdir}',
+        '--trainer.logger=False',
+        '--trainer.fast_dev_run=1',
+    ]
+
+    # With fast_dev_run!=False config should not be saved
+    with mock.patch('sys.argv', ['any.py'] + cli_args):
+        LightningCLI(BoringModel)
+    assert not os.path.isfile(config_path)
+
+    # With fast_dev_run==False config should be saved
+    cli_args[-1] = '--trainer.max_epochs=1'
+    with mock.patch('sys.argv', ['any.py'] + cli_args):
+        LightningCLI(BoringModel)
+    assert os.path.isfile(config_path)
+
+    # If run again on same directory exception should be raised since config file already exists
+    with mock.patch('sys.argv', ['any.py'] + cli_args), pytest.raises(RuntimeError):
+        LightningCLI(BoringModel)
+
+
 def test_lightning_cli_config_and_subclass_mode(tmpdir):
     config = dict(