Call on_load_checkpoint before loading state_dict (#4057)
This commit is contained in:
parent
f967fbba43
commit
dec31b3e76
|
@ -101,15 +101,16 @@ class CheckpointConnector:
|
||||||
# load model state
|
# load model state
|
||||||
model = self.trainer.get_model()
|
model = self.trainer.get_model()
|
||||||
|
|
||||||
# load the state_dict on the model automatically
|
|
||||||
model.load_state_dict(checkpoint['state_dict'])
|
|
||||||
|
|
||||||
# give the datamodule a chance to load something
|
# give the datamodule a chance to load something
|
||||||
if self.trainer.datamodule is not None:
|
if self.trainer.datamodule is not None:
|
||||||
self.trainer.datamodule.on_load_checkpoint(checkpoint)
|
self.trainer.datamodule.on_load_checkpoint(checkpoint)
|
||||||
|
|
||||||
# give model a chance to load something
|
# give model a chance to load something
|
||||||
model.on_load_checkpoint(checkpoint)
|
model.on_load_checkpoint(checkpoint)
|
||||||
|
|
||||||
|
# load the state_dict on the model automatically
|
||||||
|
model.load_state_dict(checkpoint['state_dict'])
|
||||||
|
|
||||||
if on_gpu:
|
if on_gpu:
|
||||||
model.cuda(self.trainer.root_gpu)
|
model.cuda(self.trainer.root_gpu)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue