From dec31b3e761115b0a938caf35649acd814f557c9 Mon Sep 17 00:00:00 2001
From: Rohit Gupta
Date: Thu, 15 Oct 2020 02:56:04 +0530
Subject: [PATCH] Call on_load_checkpoint before loading state_dict (#4057)

---
 .../trainer/connectors/checkpoint_connector.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
index adaac2c827..420cb6d3f5 100644
--- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py
+++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -101,15 +101,16 @@ class CheckpointConnector:
 
         # load model state
         model = self.trainer.get_model()
 
-        # load the state_dict on the model automatically
-        model.load_state_dict(checkpoint['state_dict'])
-
         # give the datamodule a chance to load something
         if self.trainer.datamodule is not None:
             self.trainer.datamodule.on_load_checkpoint(checkpoint)
 
+        # give model a chance to load something
         model.on_load_checkpoint(checkpoint)
 
+        # load the state_dict on the model automatically
+        model.load_state_dict(checkpoint['state_dict'])
+
         if on_gpu:
             model.cuda(self.trainer.root_gpu)