Call on_load_checkpoint before loading state_dict (#4057)
This commit is contained in:
parent
f967fbba43
commit
dec31b3e76
|
@ -101,15 +101,16 @@ class CheckpointConnector:
|
|||
# load model state
|
||||
model = self.trainer.get_model()
|
||||
|
||||
# load the state_dict on the model automatically
|
||||
model.load_state_dict(checkpoint['state_dict'])
|
||||
|
||||
# give the datamodule a chance to load something
|
||||
if self.trainer.datamodule is not None:
|
||||
self.trainer.datamodule.on_load_checkpoint(checkpoint)
|
||||
|
||||
# give model a chance to load something
|
||||
model.on_load_checkpoint(checkpoint)
|
||||
|
||||
# load the state_dict on the model automatically
|
||||
model.load_state_dict(checkpoint['state_dict'])
|
||||
|
||||
if on_gpu:
|
||||
model.cuda(self.trainer.root_gpu)
|
||||
|
||||
|
|
Loading…
Reference in New Issue