diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index 87951a1c14..6bed638ad5 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -277,7 +277,7 @@ class LightningCLI: seed_everything(seed, workers=True) self.before_instantiate_classes() - self.instantiate_classes() + self.instantiate_classes(self.subcommand) self.add_configure_optimizers_method_to_model(self.subcommand) if self.subcommand is not None: @@ -396,12 +396,12 @@ class LightningCLI: def before_instantiate_classes(self) -> None: """Implement to run some code before instantiating the classes.""" - def instantiate_classes(self) -> None: + def instantiate_classes(self, subcommand: Optional[str]) -> None: """Instantiates the classes and sets their attributes.""" self.config_init = self.parser.instantiate_classes(self.config) self.datamodule = self._get(self.config_init, "data") self.model = self._get(self.config_init, "model") - callbacks = [self._get(self.config_init, c) for c in self.parser.callback_keys] + callbacks = [self._get(self.config_init, c) for c in self._parser(subcommand).callback_keys] self.trainer = self.instantiate_trainer(self._get(self.config_init, "trainer"), callbacks) def instantiate_trainer(self, config: Dict[str, Any], callbacks: List[Callback]) -> Trainer: @@ -420,6 +420,14 @@ class LightningCLI: config["callbacks"].append(config_callback) return self.trainer_class(**config) + def _parser(self, subcommand: Optional[str]) -> ArgumentParser: + if subcommand is None: + return self.parser + # return the subcommand parser for the subcommand passed + action_subcommands = [a for a in self.parser._actions if isinstance(a, _ActionSubCommands)] + action_subcommand = action_subcommands[0] + return action_subcommand._name_parser_map[subcommand] + def add_configure_optimizers_method_to_model(self, subcommand: Optional[str]) -> None: """ Adds to the model an automatically generated ``configure_optimizers`` method. @@ -427,13 +435,8 @@ class LightningCLI: If a single optimizer and optionally a scheduler argument groups are added to the parser as 'AUTOMATIC', then a `configure_optimizers` method is automatically implemented in the model class. """ - if subcommand is None: - optimizers_and_lr_schedulers = self.parser.optimizers_and_lr_schedulers - else: - # get the `optimizer_and_lr_schedulers` attribute from the subcommand parser for the subcommand requested - action_subcommands = [a for a in self.parser._actions if isinstance(a, _ActionSubCommands)][0] - subcommand_parser = action_subcommands._name_parser_map[subcommand] - optimizers_and_lr_schedulers = subcommand_parser.optimizers_and_lr_schedulers + parser = self._parser(subcommand) + optimizers_and_lr_schedulers = parser.optimizers_and_lr_schedulers def get_automatic(class_type: Union[Type, Tuple[Type, ...]]) -> List[str]: automatic = [] diff --git a/tests/utilities/test_cli.py b/tests/utilities/test_cli.py index f051d150da..0aa35893fc 100644 --- a/tests/utilities/test_cli.py +++ b/tests/utilities/test_cli.py @@ -285,15 +285,20 @@ def test_lightning_cli_args_callbacks(tmpdir): assert cli.trainer.ran_asserts -def test_lightning_cli_configurable_callbacks(tmpdir): +@pytest.mark.parametrize("run", (False, True)) +def test_lightning_cli_configurable_callbacks(tmpdir, run): class MyLightningCLI(LightningCLI): def add_arguments_to_parser(self, parser): parser.add_lightning_class_args(LearningRateMonitor, "learning_rate_monitor") - cli_args = [f"--trainer.default_root_dir={tmpdir}", "--learning_rate_monitor.logging_interval=epoch"] + def fit(self, **_): + pass + + cli_args = ["fit"] if run else [] + cli_args += [f"--trainer.default_root_dir={tmpdir}", "--learning_rate_monitor.logging_interval=epoch"] with mock.patch("sys.argv", ["any.py"] + cli_args): - cli = MyLightningCLI(BoringModel, run=False) + cli = MyLightningCLI(BoringModel, run=run) callback = [c for c in cli.trainer.callbacks if isinstance(c, LearningRateMonitor)] assert len(callback) == 1