From 879d8799850eb068e3f61cb6c67744a6947cbbec Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Sun, 26 Apr 2020 17:27:45 -0400
Subject: [PATCH] fix hparams issue

---
 pytorch_lightning/trainer/training_io.py | 26 ++++++++++++++++++++++++
 1 file changed, 26 insertions(+)

diff --git a/pytorch_lightning/trainer/training_io.py b/pytorch_lightning/trainer/training_io.py
index 0e9d00c6c4..4bb3c40634 100644
--- a/pytorch_lightning/trainer/training_io.py
+++ b/pytorch_lightning/trainer/training_io.py
@@ -325,6 +325,7 @@ class TrainerIOMixin(ABC):
             checkpoint['native_amp_scaling_state'] = self.scaler.state_dict()
 
         if hasattr(model, "hparams"):
+            self.__clean_namespace(model.hparams)
             is_namespace = isinstance(model.hparams, Namespace)
             checkpoint['hparams'] = vars(model.hparams) if is_namespace else model.hparams
             checkpoint['hparams_type'] = 'namespace' if is_namespace else 'dict'
@@ -338,6 +339,31 @@ class TrainerIOMixin(ABC):
 
         return checkpoint
 
+    def __clean_namespace(self, hparams):
+        """
+        Removes all callable attributes/values from hparams so it can be pickled
+        :param hparams: model hyperparameters (Namespace or dict); cleaned in place
+        :return: None
+        """
+
+        if isinstance(hparams, Namespace):
+            del_attrs = []
+            for k in hparams.__dict__:
+                if callable(getattr(hparams, k)):
+                    del_attrs.append(k)
+
+            for k in del_attrs:
+                delattr(hparams, k)
+
+        elif isinstance(hparams, dict):
+            del_attrs = []
+            for k, v in hparams.items():
+                if callable(v):
+                    del_attrs.append(k)
+
+            for k in del_attrs:
+                del hparams[k]
+
     # --------------------
     # HPC IO
     # --------------------