Call on_load_checkpoint before loading state_dict (#4057)

This commit is contained in:
Rohit Gupta 2020-10-15 02:56:04 +05:30 committed by GitHub
parent f967fbba43
commit dec31b3e76
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 4 additions and 3 deletions

View File

@ -101,15 +101,16 @@ class CheckpointConnector:
# load model state # load model state
model = self.trainer.get_model() model = self.trainer.get_model()
# load the state_dict on the model automatically
model.load_state_dict(checkpoint['state_dict'])
# give the datamodule a chance to load something # give the datamodule a chance to load something
if self.trainer.datamodule is not None: if self.trainer.datamodule is not None:
self.trainer.datamodule.on_load_checkpoint(checkpoint) self.trainer.datamodule.on_load_checkpoint(checkpoint)
# give model a chance to load something # give model a chance to load something
model.on_load_checkpoint(checkpoint) model.on_load_checkpoint(checkpoint)
# load the state_dict on the model automatically
model.load_state_dict(checkpoint['state_dict'])
if on_gpu: if on_gpu:
model.cuda(self.trainer.root_gpu) model.cuda(self.trainer.root_gpu)