Improvements related to save of config file by LightningCLI (#7963)
* - Exclude SaveConfigCallback for fast_dev_run=True. - SaveConfigCallback give a clearer message if config file already exists. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * - Added unit test - Added entry in changelog - Improved save config docstring * Fix log line * Fixes * Fix changelog entry * Update pytorch_lightning/utilities/cli.py Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com> * Apply suggestions from code review Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com> * Suggested fixed change Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com> Co-authored-by: Carlos Mocholi <carlossmocholi@gmail.com>
This commit is contained in:
parent
971908a1aa
commit
b2e9fa814f
|
@ -175,6 +175,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Added `on_load_checkpoint` and `on_save_checkpoint` hooks to the `PrecisionPlugin` base class ([#7831](https://github.com/PyTorchLightning/pytorch-lightning/pull/7831))
|
||||
|
||||
|
||||
- `LightningCLI` now aborts with a clearer message if config already exists and disables save config during `fast_dev_run`([#7963](https://github.com/PyTorchLightning/pytorch-lightning/pull/7963))
|
||||
|
||||
|
||||
### Deprecated
|
||||
|
||||
|
||||
|
|
|
@ -75,7 +75,11 @@ class LightningArgumentParser(ArgumentParser):
|
|||
|
||||
|
||||
class SaveConfigCallback(Callback):
|
||||
"""Saves a LightningCLI config to the log_dir when training starts"""
|
||||
"""Saves a LightningCLI config to the log_dir when training starts
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the config file already exists in the directory to avoid overwriting a previous run
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -90,6 +94,11 @@ class SaveConfigCallback(Callback):
|
|||
def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
|
||||
log_dir = trainer.log_dir or trainer.default_root_dir
|
||||
config_path = os.path.join(log_dir, self.config_filename)
|
||||
if os.path.isfile(config_path):
|
||||
raise RuntimeError(
|
||||
f'{self.__class__.__name__} expected {config_path} to not exist. '
|
||||
'Aborting to avoid overwriting results of a previous run.'
|
||||
)
|
||||
self.parser.save(self.config, config_path, skip_none=False)
|
||||
|
||||
|
||||
|
@ -231,7 +240,7 @@ class LightningCLI:
|
|||
self.config_init['trainer']['callbacks'].extend(self.trainer_defaults['callbacks'])
|
||||
else:
|
||||
self.config_init['trainer']['callbacks'].append(self.trainer_defaults['callbacks'])
|
||||
if self.save_config_callback is not None:
|
||||
if self.save_config_callback and not self.config_init['trainer']['fast_dev_run']:
|
||||
config_callback = self.save_config_callback(self.parser, self.config, self.save_config_filename)
|
||||
self.config_init['trainer']['callbacks'].append(config_callback)
|
||||
self.trainer = self.trainer_class(**self.config_init['trainer'])
|
||||
|
|
|
@ -328,6 +328,31 @@ def test_lightning_cli_args(tmpdir):
|
|||
assert config['trainer'] == cli.config['trainer']
|
||||
|
||||
|
||||
def test_lightning_cli_save_config_cases(tmpdir):
|
||||
|
||||
config_path = tmpdir / 'config.yaml'
|
||||
cli_args = [
|
||||
f'--trainer.default_root_dir={tmpdir}',
|
||||
'--trainer.logger=False',
|
||||
'--trainer.fast_dev_run=1',
|
||||
]
|
||||
|
||||
# With fast_dev_run!=False config should not be saved
|
||||
with mock.patch('sys.argv', ['any.py'] + cli_args):
|
||||
LightningCLI(BoringModel)
|
||||
assert not os.path.isfile(config_path)
|
||||
|
||||
# With fast_dev_run==False config should be saved
|
||||
cli_args[-1] = '--trainer.max_epochs=1'
|
||||
with mock.patch('sys.argv', ['any.py'] + cli_args):
|
||||
LightningCLI(BoringModel)
|
||||
assert os.path.isfile(config_path)
|
||||
|
||||
# If run again on same directory exception should be raised since config file already exists
|
||||
with mock.patch('sys.argv', ['any.py'] + cli_args), pytest.raises(RuntimeError):
|
||||
LightningCLI(BoringModel)
|
||||
|
||||
|
||||
def test_lightning_cli_config_and_subclass_mode(tmpdir):
|
||||
|
||||
config = dict(
|
||||
|
|
Loading…
Reference in New Issue