Fix (weights only) checkpoints loading without pl (#3287)
* cast pl AttributeDict to dict
* fix for omegaconf
This commit is contained in:
parent f747cb6843
commit 65e6687c54
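For context, a minimal sketch (not part of the commit) of why the cast matters. AttributeDict is defined inside pytorch_lightning, so a checkpoint that pickles one cannot be unpickled in an environment without pytorch_lightning installed, while a plain dict can. The import path and the hyper-parameter values below are assumptions for illustration:

    # Sketch only: casting drops the pytorch_lightning-specific type.
    from pytorch_lightning.utilities.parsing import AttributeDict

    hparams = AttributeDict(lr=1e-3, batch_size=32)  # behaves like a dict
    plain = dict(hparams)                            # what the fixed code stores when omegaconf is absent
    assert type(plain) is dict                       # nothing pl-specific left to pickle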
@@ -370,10 +370,12 @@ class TrainerIOMixin(ABC):
         if hasattr(model, '_hparams_name'):
             checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_NAME] = model._hparams_name
         # add arguments to the checkpoint
-        checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = model.hparams
         if OMEGACONF_AVAILABLE:
+            checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = model.hparams
             if isinstance(model.hparams, Container):
                 checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_TYPE] = type(model.hparams)
+        else:
+            checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = dict(model.hparams)

         # give the model a chance to add a few things
         model.on_save_checkpoint(checkpoint)
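A hedged sketch of the loading path this change targets: reading a saved checkpoint with plain torch in an environment without pytorch_lightning. The file name is a placeholder, and the "hyper_parameters" key is assumed to be the value of LightningModule.CHECKPOINT_HYPER_PARAMS_KEY used in the hunk above:

    import torch

    # Sketch only: no pytorch_lightning import is needed to read the file,
    # because the stored hyper-parameters are now a plain dict.
    checkpoint = torch.load("example.ckpt", map_location="cpu")
    state_dict = checkpoint["state_dict"]             # model weights
    hparams = checkpoint.get("hyper_parameters", {})  # plain dict after this fix
    print(sorted(hparams))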