Fix load disparity between normal and hpc (#4526)

* Add missing load functionality in hpc

* Add general file load for hpc

* Add mark in CHANGELOG

* Fix Typo Li**hg**tning

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>

* Refactor line separation

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>

* Fix entangled fixation commit

* Fix naming of restore_model_states

* Fix amp restore place

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
Co-authored-by: chaton <thomas@grid.ai>
Authored by tarepan on 2020-11-10 02:26:38 +09:00; committed by GitHub
parent 23719e3c05
commit 41c9bee4f0
2 changed files with 24 additions and 22 deletions

CHANGELOG.md

@@ -47,7 +47,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
+- Fixed feature-lack in hpc load ([#4526](https://github.com/PyTorchLightning/pytorch-lightning/pull/4526))
 
 ## [1.0.5] - 2020-11-03

pytorch_lightning/trainer/connectors/checkpoint_connector.py

@@ -103,6 +103,20 @@ class CheckpointConnector:
         # load model state
         model = self.trainer.get_model()
 
+        # restore model and datamodule state
+        self.restore_model_state(model, checkpoint)
+
+        if on_gpu:
+            model.cuda(self.trainer.root_gpu)
+
+        # restore training state
+        self.restore_training_state(checkpoint)
+
+    def restore_model_state(self, model: LightningModule, checkpoint) -> None:
+        """
+        Restore model states from a 'PyTorch-Lightning checkpoint' dictionary object
+        """
+
         # give the datamodule a chance to load something
         if self.trainer.datamodule is not None:
             self.trainer.datamodule.on_load_checkpoint(checkpoint)
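For orientation: a 'PyTorch-Lightning checkpoint' is an ordinary dictionary saved with torch.save, and the new restore_model_state helper consumes its 'state_dict' entry after running the datamodule and model load hooks. A minimal standalone sketch of that idea, using a plain torch.nn.Module in place of a LightningModule (illustration only, not Lightning's code):

```python
# Illustration only: a checkpoint is just a dict with a 'state_dict' entry,
# and one shared helper restores it no matter which code path loaded the file.
import torch
import torch.nn as nn

def restore_model_state(model: nn.Module, checkpoint: dict) -> None:
    """Restore model weights from a checkpoint-style dictionary."""
    model.load_state_dict(checkpoint['state_dict'])

model = nn.Linear(4, 2)
torch.save({'state_dict': model.state_dict(), 'epoch': 3}, 'model.ckpt')

# load on CPU first, then restore through the shared helper
checkpoint = torch.load('model.ckpt', map_location=lambda storage, loc: storage)
restored = nn.Linear(4, 2)
restore_model_state(restored, checkpoint)
```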
@@ -113,18 +127,6 @@ class CheckpointConnector:
         # restore the state_dict on the model
         model.load_state_dict(checkpoint['state_dict'])
 
-        if on_gpu:
-            model.cuda(self.trainer.root_gpu)
-
-        # restore amp scaling
-        if self.trainer.amp_backend == AMPType.NATIVE and 'native_amp_scaling_state' in checkpoint:
-            self.trainer.scaler.load_state_dict(checkpoint['native_amp_scaling_state'])
-        elif self.trainer.amp_backend == AMPType.APEX and 'amp_scaling_state' in checkpoint:
-            amp.load_state_dict(checkpoint['amp_scaling_state'])
-
-        # load training state (affects trainer only)
-        self.restore_training_state(checkpoint)
-
     def restore_training_state(self, checkpoint):
         """
         Restore trainer state.
@@ -147,6 +149,12 @@ class CheckpointConnector:
                 " where `model.ckpt` is your checkpoint file."
             )
 
+        # restore amp scaling
+        if self.trainer.amp_backend == AMPType.NATIVE and 'native_amp_scaling_state' in checkpoint:
+            self.trainer.scaler.load_state_dict(checkpoint['native_amp_scaling_state'])
+        elif self.trainer.amp_backend == AMPType.APEX and 'amp_scaling_state' in checkpoint:
+            amp.load_state_dict(checkpoint['amp_scaling_state'])
+
         # restore callback states
         self.trainer.on_load_checkpoint(checkpoint)
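The block moved here restores the AMP scaler together with the rest of the training state. For the native backend, the round-trip behind 'native_amp_scaling_state' looks roughly like this in plain PyTorch (a sketch assuming torch.cuda.amp.GradScaler, not Lightning's actual code):

```python
# Sketch: torch.cuda.amp.GradScaler is the object behind 'native_amp_scaling_state';
# its state_dict round-trips through the checkpoint like any other state.
import torch

scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())
checkpoint = {'native_amp_scaling_state': scaler.state_dict()}

new_scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())
if 'native_amp_scaling_state' in checkpoint:
    new_scaler.load_state_dict(checkpoint['native_amp_scaling_state'])
```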
@@ -336,19 +344,13 @@ class CheckpointConnector:
         filepath = '{}/hpc_ckpt_{}.ckpt'.format(folderpath, self.max_ckpt_in_folder(folderpath))
 
         # load on CPU first
-        checkpoint = torch.load(filepath, map_location=lambda storage, loc: storage)
+        checkpoint = pl_load(filepath, map_location=lambda storage, loc: storage)
 
         # load model state
         model = self.trainer.get_model()
 
-        # load the state_dict on the model automatically
-        model.load_state_dict(checkpoint['state_dict'])
-
-        # restore amp scaling
-        if self.trainer.amp_backend == AMPType.NATIVE and 'native_amp_scaling_state' in checkpoint:
-            self.trainer.scaler.load_state_dict(checkpoint['native_amp_scaling_state'])
-        elif self.trainer.amp_backend == AMPType.APEX and 'amp_scaling_state' in checkpoint:
-            amp.load_state_dict(checkpoint['amp_scaling_state'])
-
+        # restore states from 'PyTorch-Lightning checkpoint' dictionary object
+        self.restore_model_state(model, checkpoint)
+
         if self.trainer.root_gpu is not None:
             model.cuda(self.trainer.root_gpu)
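With hpc_load delegating to the same restore_model_state as restore, the normal and HPC resume paths can no longer drift apart. A rough sketch of the resulting shape, with a hypothetical TinyCheckpointConnector standing in for the real class:

```python
# Hypothetical miniature of the pattern: every load path funnels through one
# restore helper, so a feature added there applies to normal and HPC resume alike.
import torch
import torch.nn as nn

class TinyCheckpointConnector:
    def __init__(self, model: nn.Module):
        self.model = model

    def restore_model_state(self, checkpoint: dict) -> None:
        # single place for state_dict loading (and, in Lightning, the load hooks)
        self.model.load_state_dict(checkpoint['state_dict'])

    def restore(self, checkpoint_path: str) -> None:
        # normal resume: read the file, then use the shared helper
        checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
        self.restore_model_state(checkpoint)

    def hpc_load(self, checkpoint_path: str) -> None:
        # HPC resume: same helper, so no feature disparity between the two paths
        checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
        self.restore_model_state(checkpoint)

model = nn.Linear(8, 1)
torch.save({'state_dict': model.state_dict()}, 'hpc_ckpt_1.ckpt')
TinyCheckpointConnector(nn.Linear(8, 1)).hpc_load('hpc_ckpt_1.ckpt')
```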