Allow separate config files for parameters with class type when LightningCLI is in subclass_mode=False (#10286)

This commit is contained in:
Mauricio Villegas 2021-11-01 18:24:31 +01:00 committed by GitHub
parent c52d7ba73d
commit 828b5315aa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 10 additions and 3 deletions

View File

@ -179,6 +179,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Enabled `on_load_checkpoint` for `LightningDataModule` for all `trainer_fn` ([#10238](https://github.com/PyTorchLightning/pytorch-lightning/pull/10238))
- Allow separate config files for parameters with class type when LightningCLI is in subclass_mode=False ([#10286](https://github.com/PyTorchLightning/pytorch-lightning/pull/10286))
### Deprecated
- Deprecated Trainer argument `terminate_on_nan` in favor of `detect_anomaly`([#9175](https://github.com/PyTorchLightning/pytorch-lightning/pull/9175))

View File

@ -155,7 +155,11 @@ class LightningArgumentParser(ArgumentParser):
if subclass_mode:
return self.add_subclass_arguments(lightning_class, nested_key, fail_untyped=False, required=required)
return self.add_class_arguments(
lightning_class, nested_key, fail_untyped=False, instantiate=not issubclass(lightning_class, Trainer)
lightning_class,
nested_key,
fail_untyped=False,
instantiate=not issubclass(lightning_class, Trainer),
sub_configs=True,
)
raise MisconfigurationException(
f"Cannot add arguments from: {lightning_class}. You should provide either a callable or a subclass of: "
@ -184,7 +188,7 @@ class LightningArgumentParser(ArgumentParser):
self.add_subclass_arguments(optimizer_class, nested_key, **kwargs)
self.set_choices(nested_key, optimizer_class)
else:
self.add_class_arguments(optimizer_class, nested_key, **kwargs)
self.add_class_arguments(optimizer_class, nested_key, sub_configs=True, **kwargs)
self._optimizers[nested_key] = (optimizer_class, link_to)
def add_lr_scheduler_args(
@ -209,7 +213,7 @@ class LightningArgumentParser(ArgumentParser):
self.add_subclass_arguments(lr_scheduler_class, nested_key, **kwargs)
self.set_choices(nested_key, lr_scheduler_class)
else:
self.add_class_arguments(lr_scheduler_class, nested_key, **kwargs)
self.add_class_arguments(lr_scheduler_class, nested_key, sub_configs=True, **kwargs)
self._lr_schedulers[nested_key] = (lr_scheduler_class, link_to)
def parse_args(self, *args: Any, **kwargs: Any) -> Dict[str, Any]: