Fix (weights only) checkpoints loading without pl (#3287)

* cast pl AttributeDict to dict

* fix for omegaconf
This commit is contained in:
s-rog 2020-09-02 21:36:42 +08:00 committed by GitHub
parent f747cb6843
commit 65e6687c54
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 3 additions and 1 deletion

View File

@ -370,10 +370,12 @@ class TrainerIOMixin(ABC):
if hasattr(model, '_hparams_name'):
checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_NAME] = model._hparams_name
# add arguments to the checkpoint
checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = model.hparams
if OMEGACONF_AVAILABLE:
checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = model.hparams
if isinstance(model.hparams, Container):
checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_TYPE] = type(model.hparams)
else:
checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = dict(model.hparams)
# give the model a chance to add a few things
model.on_save_checkpoint(checkpoint)