Fix callback instantiation with CLI subcommands (#9203)

This commit is contained in:
Carlos Mocholí 2021-08-30 17:44:18 +02:00 committed by GitHub
parent c0bd658354
commit 1117e17409
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 21 additions and 13 deletions

View File

@ -277,7 +277,7 @@ class LightningCLI:
seed_everything(seed, workers=True)
self.before_instantiate_classes()
self.instantiate_classes()
self.instantiate_classes(self.subcommand)
self.add_configure_optimizers_method_to_model(self.subcommand)
if self.subcommand is not None:
@ -396,12 +396,12 @@ class LightningCLI:
def before_instantiate_classes(self) -> None:
"""Implement to run some code before instantiating the classes."""
def instantiate_classes(self) -> None:
def instantiate_classes(self, subcommand: Optional[str]) -> None:
"""Instantiates the classes and sets their attributes."""
self.config_init = self.parser.instantiate_classes(self.config)
self.datamodule = self._get(self.config_init, "data")
self.model = self._get(self.config_init, "model")
callbacks = [self._get(self.config_init, c) for c in self.parser.callback_keys]
callbacks = [self._get(self.config_init, c) for c in self._parser(subcommand).callback_keys]
self.trainer = self.instantiate_trainer(self._get(self.config_init, "trainer"), callbacks)
def instantiate_trainer(self, config: Dict[str, Any], callbacks: List[Callback]) -> Trainer:
@ -420,6 +420,14 @@ class LightningCLI:
config["callbacks"].append(config_callback)
return self.trainer_class(**config)
def _parser(self, subcommand: Optional[str]) -> ArgumentParser:
if subcommand is None:
return self.parser
# return the subcommand parser for the subcommand passed
action_subcommands = [a for a in self.parser._actions if isinstance(a, _ActionSubCommands)]
action_subcommand = action_subcommands[0]
return action_subcommand._name_parser_map[subcommand]
def add_configure_optimizers_method_to_model(self, subcommand: Optional[str]) -> None:
"""
Adds to the model an automatically generated ``configure_optimizers`` method.
@ -427,13 +435,8 @@ class LightningCLI:
If a single optimizer and optionally a scheduler argument groups are added to the parser as 'AUTOMATIC',
then a `configure_optimizers` method is automatically implemented in the model class.
"""
if subcommand is None:
optimizers_and_lr_schedulers = self.parser.optimizers_and_lr_schedulers
else:
# get the `optimizer_and_lr_schedulers` attribute from the subcommand parser for the subcommand requested
action_subcommands = [a for a in self.parser._actions if isinstance(a, _ActionSubCommands)][0]
subcommand_parser = action_subcommands._name_parser_map[subcommand]
optimizers_and_lr_schedulers = subcommand_parser.optimizers_and_lr_schedulers
parser = self._parser(subcommand)
optimizers_and_lr_schedulers = parser.optimizers_and_lr_schedulers
def get_automatic(class_type: Union[Type, Tuple[Type, ...]]) -> List[str]:
automatic = []

View File

@ -285,15 +285,20 @@ def test_lightning_cli_args_callbacks(tmpdir):
assert cli.trainer.ran_asserts
def test_lightning_cli_configurable_callbacks(tmpdir):
@pytest.mark.parametrize("run", (False, True))
def test_lightning_cli_configurable_callbacks(tmpdir, run):
class MyLightningCLI(LightningCLI):
def add_arguments_to_parser(self, parser):
parser.add_lightning_class_args(LearningRateMonitor, "learning_rate_monitor")
cli_args = [f"--trainer.default_root_dir={tmpdir}", "--learning_rate_monitor.logging_interval=epoch"]
def fit(self, **_):
pass
cli_args = ["fit"] if run else []
cli_args += [f"--trainer.default_root_dir={tmpdir}", "--learning_rate_monitor.logging_interval=epoch"]
with mock.patch("sys.argv", ["any.py"] + cli_args):
cli = MyLightningCLI(BoringModel, run=False)
cli = MyLightningCLI(BoringModel, run=run)
callback = [c for c in cli.trainer.callbacks if isinstance(c, LearningRateMonitor)]
assert len(callback) == 1