Logging Non-matching keys when loading from checkpoint in non-strict … (#8152)
Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com> Co-authored-by: ananthsub <ananth.subramaniam@gmail.com> Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
This commit is contained in:
parent
acb6f26006
commit
1afc1ca7ef
|
@ -202,7 +202,17 @@ class ModelIO(object):
|
|||
model.on_load_checkpoint(checkpoint)
|
||||
|
||||
# load the state_dict on the model automatically
|
||||
model.load_state_dict(checkpoint['state_dict'], strict=strict)
|
||||
keys = model.load_state_dict(checkpoint['state_dict'], strict=strict)
|
||||
|
||||
if not strict:
|
||||
if keys.missing_keys:
|
||||
rank_zero_warn(
|
||||
f"Found keys that are in the model state dict but not in the checkpoint: {keys.missing_keys}"
|
||||
)
|
||||
if keys.unexpected_keys:
|
||||
rank_zero_warn(
|
||||
f"Found keys that are not in the model state dict but in the checkpoint: {keys.unexpected_keys}"
|
||||
)
|
||||
|
||||
return model
|
||||
|
||||
|
|
Loading…
Reference in New Issue