Fix load disparity between normal and hpc (#4526)
* Add missing load functionality in hpc
* Add general file load for hpc
* Add mark in CHANGELOG
* Fix Typo Li**hg**tning
* Refactor line separation
* Fix entangled fixation commit
* Fix naming of restore_model_states
* Fix amp restore place

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
Co-authored-by: chaton <thomas@grid.ai>
This commit is contained in:
parent 23719e3c05
commit 41c9bee4f0
CHANGELOG.md

@@ -47,7 +47,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Fixed

+- Fixed feature-lack in hpc load ([#4526](https://github.com/PyTorchLightning/pytorch-lightning/pull/4526))

 ## [1.0.5] - 2020-11-03
pytorch_lightning/trainer/connectors/checkpoint_connector.py

@@ -103,6 +103,20 @@ class CheckpointConnector:
         # load model state
         model = self.trainer.get_model()

+        # restore model and datamodule state
+        self.restore_model_state(model, checkpoint)
+
+        if on_gpu:
+            model.cuda(self.trainer.root_gpu)
+
+        # restore training state
+        self.restore_training_state(checkpoint)
+
+    def restore_model_state(self, model: LightningModule, checkpoint) -> None:
+        """
+        Restore model states from a 'PyTorch-Lightning checkpoint' dictionary object
+        """
+
         # give the datamodule a chance to load something
         if self.trainer.datamodule is not None:
             self.trainer.datamodule.on_load_checkpoint(checkpoint)
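The practical effect of routing everything through `restore_model_state` is that datamodule state now survives an HPC resume as well as a normal one. A minimal sketch of the user-facing side, where the `vocab` attribute is purely illustrative and the symmetric `on_save_checkpoint` hook is assumed to exist on datamodules of this era:

```python
import pytorch_lightning as pl

class MyDataModule(pl.LightningDataModule):
    """Illustrative datamodule that stashes extra state in the checkpoint dict."""

    def __init__(self):
        super().__init__()
        self.vocab = {}  # hypothetical state worth checkpointing

    def on_save_checkpoint(self, checkpoint):
        # write extra entries into the Lightning checkpoint dictionary
        checkpoint['vocab'] = self.vocab

    def on_load_checkpoint(self, checkpoint):
        # after #4526 this hook also fires on the hpc_load path
        self.vocab = checkpoint['vocab']
```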
@@ -113,18 +127,6 @@ class CheckpointConnector:
         # restore the state_dict on the model
         model.load_state_dict(checkpoint['state_dict'])

-        if on_gpu:
-            model.cuda(self.trainer.root_gpu)
-
-        # restore amp scaling
-        if self.trainer.amp_backend == AMPType.NATIVE and 'native_amp_scaling_state' in checkpoint:
-            self.trainer.scaler.load_state_dict(checkpoint['native_amp_scaling_state'])
-        elif self.trainer.amp_backend == AMPType.APEX and 'amp_scaling_state' in checkpoint:
-            amp.load_state_dict(checkpoint['amp_scaling_state'])
-
-        # load training state (affects trainer only)
-        self.restore_training_state(checkpoint)
-
     def restore_training_state(self, checkpoint):
         """
         Restore trainer state.
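For readers unfamiliar with the dictionary being passed around: a "PyTorch-Lightning checkpoint" is a plain Python dict serialized with torch, whose `'state_dict'` entry holds ordinary `nn.Module` weights. A self-contained sketch (key name taken from the diff; the tiny model is a stand-in):

```python
import torch
from torch import nn

model = nn.Linear(4, 2)  # stand-in for a LightningModule's parameters
torch.save({'state_dict': model.state_dict()}, 'model.ckpt')

# load on CPU first, exactly as the connector does, then restore the weights
checkpoint = torch.load('model.ckpt', map_location=lambda storage, loc: storage)
model.load_state_dict(checkpoint['state_dict'])
```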
@@ -147,6 +149,12 @@ class CheckpointConnector:
             " where `model.ckpt` is your checkpoint file."
         )

+        # restore amp scaling
+        if self.trainer.amp_backend == AMPType.NATIVE and 'native_amp_scaling_state' in checkpoint:
+            self.trainer.scaler.load_state_dict(checkpoint['native_amp_scaling_state'])
+        elif self.trainer.amp_backend == AMPType.APEX and 'amp_scaling_state' in checkpoint:
+            amp.load_state_dict(checkpoint['amp_scaling_state'])
+
         # restore callback states
         self.trainer.on_load_checkpoint(checkpoint)
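The "native amp scaling state" restored here is the `state_dict` of `torch.cuda.amp.GradScaler`, which tracks the loss-scale factor and its growth schedule. A minimal sketch of the round trip this hunk moves into `restore_training_state` (key name from the diff above; `GradScaler` degrades to a no-op when CUDA is unavailable):

```python
import torch

scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())
checkpoint = {'native_amp_scaling_state': scaler.state_dict()}  # as written on save

# on resume: a fresh scaler picks up the saved loss-scale state
new_scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())
if 'native_amp_scaling_state' in checkpoint:
    new_scaler.load_state_dict(checkpoint['native_amp_scaling_state'])
```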
@@ -336,19 +344,13 @@ class CheckpointConnector:
         filepath = '{}/hpc_ckpt_{}.ckpt'.format(folderpath, self.max_ckpt_in_folder(folderpath))

         # load on CPU first
-        checkpoint = torch.load(filepath, map_location=lambda storage, loc: storage)
+        checkpoint = pl_load(filepath, map_location=lambda storage, loc: storage)

         # load model state
         model = self.trainer.get_model()

-        # load the state_dict on the model automatically
-        model.load_state_dict(checkpoint['state_dict'])
-
-        # restore amp scaling
-        if self.trainer.amp_backend == AMPType.NATIVE and 'native_amp_scaling_state' in checkpoint:
-            self.trainer.scaler.load_state_dict(checkpoint['native_amp_scaling_state'])
-        elif self.trainer.amp_backend == AMPType.APEX and 'amp_scaling_state' in checkpoint:
-            amp.load_state_dict(checkpoint['amp_scaling_state'])
+        # restore states from 'PyTorch-Lightning checkpoint' dictionary object
+        self.restore_model_state(model, checkpoint)

         if self.trainer.root_gpu is not None:
             model.cuda(self.trainer.root_gpu)
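`pl_load` is Lightning's wrapper around `torch.load`; in this era it was imported from `pytorch_lightning.utilities.cloud_io`, and its main advantage over bare `torch.load` is fsspec-based support for remote paths. A hedged sketch of the CPU-first pattern the hunk preserves (the import path is assumed from this version of the codebase):

```python
import torch
from pytorch_lightning.utilities.cloud_io import load as pl_load

torch.save({'state_dict': {}}, 'hpc_ckpt_1.ckpt')  # stand-in checkpoint file

# map every tensor to CPU on load; move to GPU only after all state is restored
checkpoint = pl_load('hpc_ckpt_1.ckpt', map_location=lambda storage, loc: storage)
```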