diff --git a/CHANGELOG.md b/CHANGELOG.md index a33aa60fde..60702a2866 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -179,6 +179,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Enabled `on_load_checkpoint` for `LightningDataModule` for all `trainer_fn` ([#10238](https://github.com/PyTorchLightning/pytorch-lightning/pull/10238)) +- Allow separate config files for parameters with class type when LightningCLI is in subclass_mode=False ([#10286](https://github.com/PyTorchLightning/pytorch-lightning/pull/10286)) + + ### Deprecated - Deprecated Trainer argument `terminate_on_nan` in favor of `detect_anomaly`([#9175](https://github.com/PyTorchLightning/pytorch-lightning/pull/9175)) diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index b6c3b22d7b..9d8cca7db1 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -155,7 +155,11 @@ class LightningArgumentParser(ArgumentParser): if subclass_mode: return self.add_subclass_arguments(lightning_class, nested_key, fail_untyped=False, required=required) return self.add_class_arguments( - lightning_class, nested_key, fail_untyped=False, instantiate=not issubclass(lightning_class, Trainer) + lightning_class, + nested_key, + fail_untyped=False, + instantiate=not issubclass(lightning_class, Trainer), + sub_configs=True, ) raise MisconfigurationException( f"Cannot add arguments from: {lightning_class}. You should provide either a callable or a subclass of: " @@ -184,7 +188,7 @@ class LightningArgumentParser(ArgumentParser): self.add_subclass_arguments(optimizer_class, nested_key, **kwargs) self.set_choices(nested_key, optimizer_class) else: - self.add_class_arguments(optimizer_class, nested_key, **kwargs) + self.add_class_arguments(optimizer_class, nested_key, sub_configs=True, **kwargs) self._optimizers[nested_key] = (optimizer_class, link_to) def add_lr_scheduler_args( @@ -209,7 +213,7 @@ class LightningArgumentParser(ArgumentParser): self.add_subclass_arguments(lr_scheduler_class, nested_key, **kwargs) self.set_choices(nested_key, lr_scheduler_class) else: - self.add_class_arguments(lr_scheduler_class, nested_key, **kwargs) + self.add_class_arguments(lr_scheduler_class, nested_key, sub_configs=True, **kwargs) self._lr_schedulers[nested_key] = (lr_scheduler_class, link_to) def parse_args(self, *args: Any, **kwargs: Any) -> Dict[str, Any]: