From 65e6687c54937db0f9bdbdb089ac9d288457d6f8 Mon Sep 17 00:00:00 2001
From: s-rog <55400948+s-rog@users.noreply.github.com>
Date: Wed, 2 Sep 2020 21:36:42 +0800
Subject: [PATCH] Fix (weights only) checkpoints loading without pl (#3287)

* cast pl AttributeDict to dict

* fix for omegaconf
---
 pytorch_lightning/trainer/training_io.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/pytorch_lightning/trainer/training_io.py b/pytorch_lightning/trainer/training_io.py
index f10eb96be1..0f7df23e71 100644
--- a/pytorch_lightning/trainer/training_io.py
+++ b/pytorch_lightning/trainer/training_io.py
@@ -370,10 +370,12 @@ class TrainerIOMixin(ABC):
             if hasattr(model, '_hparams_name'):
                 checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_NAME] = model._hparams_name
             # add arguments to the checkpoint
-            checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = model.hparams
             if OMEGACONF_AVAILABLE:
+                checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = model.hparams
                 if isinstance(model.hparams, Container):
                     checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_TYPE] = type(model.hparams)
+            else:
+                checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = dict(model.hparams)
 
         # give the model a chance to add a few things
         model.on_save_checkpoint(checkpoint)
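
Illustrative note (not part of the patch): a minimal sketch of what the dict(...) cast enables, assuming a weights-only checkpoint saved at the hypothetical path "model.ckpt" and that LightningModule.CHECKPOINT_HYPER_PARAMS_KEY resolves to "hyper_parameters" in this release. Because the hparams are now stored as a builtin dict rather than pytorch_lightning's AttributeDict, the checkpoint can be unpickled with plain torch in an environment where pytorch_lightning is not installed.

    # Environment without pytorch_lightning; only torch is assumed installed.
    import torch

    # "model.ckpt" is a hypothetical checkpoint saved with save_weights_only=True.
    checkpoint = torch.load("model.ckpt", map_location="cpu")

    # After this patch the stored hparams are a plain builtin dict, so
    # unpickling no longer requires the pytorch_lightning.utilities.AttributeDict class.
    hparams = checkpoint.get("hyper_parameters", {})
    state_dict = checkpoint["state_dict"]

    print(type(hparams))          # <class 'dict'> when omegaconf is not used
    print(list(state_dict)[:3])   # first few parameter names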