Fix callback instantiation with CLI subcommands (#9203)
This commit is contained in:
parent
c0bd658354
commit
1117e17409
|
@ -277,7 +277,7 @@ class LightningCLI:
|
|||
seed_everything(seed, workers=True)
|
||||
|
||||
self.before_instantiate_classes()
|
||||
self.instantiate_classes()
|
||||
self.instantiate_classes(self.subcommand)
|
||||
self.add_configure_optimizers_method_to_model(self.subcommand)
|
||||
|
||||
if self.subcommand is not None:
|
||||
|
@ -396,12 +396,12 @@ class LightningCLI:
|
|||
def before_instantiate_classes(self) -> None:
|
||||
"""Implement to run some code before instantiating the classes."""
|
||||
|
||||
def instantiate_classes(self) -> None:
|
||||
def instantiate_classes(self, subcommand: Optional[str]) -> None:
|
||||
"""Instantiates the classes and sets their attributes."""
|
||||
self.config_init = self.parser.instantiate_classes(self.config)
|
||||
self.datamodule = self._get(self.config_init, "data")
|
||||
self.model = self._get(self.config_init, "model")
|
||||
callbacks = [self._get(self.config_init, c) for c in self.parser.callback_keys]
|
||||
callbacks = [self._get(self.config_init, c) for c in self._parser(subcommand).callback_keys]
|
||||
self.trainer = self.instantiate_trainer(self._get(self.config_init, "trainer"), callbacks)
|
||||
|
||||
def instantiate_trainer(self, config: Dict[str, Any], callbacks: List[Callback]) -> Trainer:
|
||||
|
@ -420,6 +420,14 @@ class LightningCLI:
|
|||
config["callbacks"].append(config_callback)
|
||||
return self.trainer_class(**config)
|
||||
|
||||
def _parser(self, subcommand: Optional[str]) -> ArgumentParser:
|
||||
if subcommand is None:
|
||||
return self.parser
|
||||
# return the subcommand parser for the subcommand passed
|
||||
action_subcommands = [a for a in self.parser._actions if isinstance(a, _ActionSubCommands)]
|
||||
action_subcommand = action_subcommands[0]
|
||||
return action_subcommand._name_parser_map[subcommand]
|
||||
|
||||
def add_configure_optimizers_method_to_model(self, subcommand: Optional[str]) -> None:
|
||||
"""
|
||||
Adds to the model an automatically generated ``configure_optimizers`` method.
|
||||
|
@ -427,13 +435,8 @@ class LightningCLI:
|
|||
If a single optimizer and optionally a scheduler argument groups are added to the parser as 'AUTOMATIC',
|
||||
then a `configure_optimizers` method is automatically implemented in the model class.
|
||||
"""
|
||||
if subcommand is None:
|
||||
optimizers_and_lr_schedulers = self.parser.optimizers_and_lr_schedulers
|
||||
else:
|
||||
# get the `optimizer_and_lr_schedulers` attribute from the subcommand parser for the subcommand requested
|
||||
action_subcommands = [a for a in self.parser._actions if isinstance(a, _ActionSubCommands)][0]
|
||||
subcommand_parser = action_subcommands._name_parser_map[subcommand]
|
||||
optimizers_and_lr_schedulers = subcommand_parser.optimizers_and_lr_schedulers
|
||||
parser = self._parser(subcommand)
|
||||
optimizers_and_lr_schedulers = parser.optimizers_and_lr_schedulers
|
||||
|
||||
def get_automatic(class_type: Union[Type, Tuple[Type, ...]]) -> List[str]:
|
||||
automatic = []
|
||||
|
|
|
@ -285,15 +285,20 @@ def test_lightning_cli_args_callbacks(tmpdir):
|
|||
assert cli.trainer.ran_asserts
|
||||
|
||||
|
||||
def test_lightning_cli_configurable_callbacks(tmpdir):
|
||||
@pytest.mark.parametrize("run", (False, True))
|
||||
def test_lightning_cli_configurable_callbacks(tmpdir, run):
|
||||
class MyLightningCLI(LightningCLI):
|
||||
def add_arguments_to_parser(self, parser):
|
||||
parser.add_lightning_class_args(LearningRateMonitor, "learning_rate_monitor")
|
||||
|
||||
cli_args = [f"--trainer.default_root_dir={tmpdir}", "--learning_rate_monitor.logging_interval=epoch"]
|
||||
def fit(self, **_):
|
||||
pass
|
||||
|
||||
cli_args = ["fit"] if run else []
|
||||
cli_args += [f"--trainer.default_root_dir={tmpdir}", "--learning_rate_monitor.logging_interval=epoch"]
|
||||
|
||||
with mock.patch("sys.argv", ["any.py"] + cli_args):
|
||||
cli = MyLightningCLI(BoringModel, run=False)
|
||||
cli = MyLightningCLI(BoringModel, run=run)
|
||||
|
||||
callback = [c for c in cli.trainer.callbacks if isinstance(c, LearningRateMonitor)]
|
||||
assert len(callback) == 1
|
||||
|
|
Loading…
Reference in New Issue