From 1afc1ca7ef8d069895f4065fb4000b24771585f4 Mon Sep 17 00:00:00 2001 From: karthikrangasai <39360170+karthikrangasai@users.noreply.github.com> Date: Thu, 1 Jul 2021 00:03:13 +0530 Subject: [PATCH] =?UTF-8?q?Logging=20Non-matching=20keys=20when=20loading?= =?UTF-8?q?=20from=20checkpoint=20in=20non-strict=20=E2=80=A6=20(#8152)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com> Co-authored-by: ananthsub Co-authored-by: Carlos MocholĂ­ --- pytorch_lightning/core/saving.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/core/saving.py b/pytorch_lightning/core/saving.py index ffa9b0a135..74862735ab 100644 --- a/pytorch_lightning/core/saving.py +++ b/pytorch_lightning/core/saving.py @@ -202,7 +202,17 @@ class ModelIO(object): model.on_load_checkpoint(checkpoint) # load the state_dict on the model automatically - model.load_state_dict(checkpoint['state_dict'], strict=strict) + keys = model.load_state_dict(checkpoint['state_dict'], strict=strict) + + if not strict: + if keys.missing_keys: + rank_zero_warn( + f"Found keys that are in the model state dict but not in the checkpoint: {keys.missing_keys}" + ) + if keys.unexpected_keys: + rank_zero_warn( + f"Found keys that are not in the model state dict but in the checkpoint: {keys.unexpected_keys}" + ) return model