Logging Non-matching keys when loading from checkpoint in non-strict … (#8152)
Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com> Co-authored-by: ananthsub <ananth.subramaniam@gmail.com> Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
This commit is contained in:
parent
acb6f26006
commit
1afc1ca7ef
|
@ -202,7 +202,17 @@ class ModelIO(object):
|
||||||
model.on_load_checkpoint(checkpoint)
|
model.on_load_checkpoint(checkpoint)
|
||||||
|
|
||||||
# load the state_dict on the model automatically
|
# load the state_dict on the model automatically
|
||||||
model.load_state_dict(checkpoint['state_dict'], strict=strict)
|
keys = model.load_state_dict(checkpoint['state_dict'], strict=strict)
|
||||||
|
|
||||||
|
if not strict:
|
||||||
|
if keys.missing_keys:
|
||||||
|
rank_zero_warn(
|
||||||
|
f"Found keys that are in the model state dict but not in the checkpoint: {keys.missing_keys}"
|
||||||
|
)
|
||||||
|
if keys.unexpected_keys:
|
||||||
|
rank_zero_warn(
|
||||||
|
f"Found keys that are not in the model state dict but in the checkpoint: {keys.unexpected_keys}"
|
||||||
|
)
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue