Logging Non-matching keys when loading from checkpoint in non-strict … (#8152)

Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
Co-authored-by: ananthsub <ananth.subramaniam@gmail.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
This commit is contained in:
karthikrangasai 2021-07-01 00:03:13 +05:30 committed by GitHub
parent acb6f26006
commit 1afc1ca7ef
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 11 additions and 1 deletions

View File

@ -202,7 +202,17 @@ class ModelIO(object):
model.on_load_checkpoint(checkpoint)
# load the state_dict on the model automatically
model.load_state_dict(checkpoint['state_dict'], strict=strict)
keys = model.load_state_dict(checkpoint['state_dict'], strict=strict)
if not strict:
if keys.missing_keys:
rank_zero_warn(
f"Found keys that are in the model state dict but not in the checkpoint: {keys.missing_keys}"
)
if keys.unexpected_keys:
rank_zero_warn(
f"Found keys that are not in the model state dict but in the checkpoint: {keys.unexpected_keys}"
)
return model