Add automatic restore of trainer and model state from the latest checkpoint
This commit is contained in:
parent
a79de1ec8e
commit
d3f19c8321
|
@ -244,6 +244,30 @@ class Trainer(TrainerIO):
|
|||
'''
|
||||
raise ModuleNotFoundError(msg)
|
||||
|
||||
# restore training and model
|
||||
self.restore_state_if_existing_checkpoint()
|
||||
|
||||
def restore_state_if_existing_checkpoint(self):
    """Restore trainer and model state from the newest checkpoint, if any.

    Scans ``self.checkpoint_callback.filepath`` for files containing
    ``.ckpt`` whose names embed an epoch number after ``epoch_``, picks the
    one with the highest epoch, and restores it via ``self.restore``.
    Does nothing when the folder holds no matching checkpoint.
    """
    # restore trainer state and model if there is a weight for this experiment
    last_epoch = -1
    last_ckpt_name = None

    # find last epoch
    checkpoints = os.listdir(self.checkpoint_callback.filepath)
    for name in checkpoints:
        if '.ckpt' in name:
            epoch = name.split('epoch_')[1]
            # BUG FIX: compare epochs numerically. The original left `epoch`
            # a string, so `epoch > last_epoch` raised TypeError on Python 3
            # (str vs int) and compared lexicographically on Python 2,
            # where 'epoch_9' would beat 'epoch_10'.
            epoch = int(re.sub('[^0-9]', '', epoch))

            if epoch > last_epoch:
                last_epoch = epoch
                last_ckpt_name = name

    # restore last checkpoint — guard against an empty folder, where the
    # original passed None into os.path.join and crashed
    if last_ckpt_name is not None:
        last_ckpt_path = os.path.join(self.checkpoint_callback.filepath, last_ckpt_name)
        self.restore(last_ckpt_path, self.on_gpu)
        print(f'model and trainer restored from checkpoint: {last_ckpt_path}')
|
||||
|
||||
@property
def data_parallel(self):
    """Whether any data-parallel mode (DP or DDP) is enabled.

    Mirrors ``self.use_dp or self.use_ddp``: returns ``use_dp`` when it is
    truthy, otherwise ``use_ddp``.
    """
    dp_flag = self.use_dp
    ddp_flag = self.use_ddp
    # equivalent to `dp_flag or ddp_flag`, spelled out for clarity
    return dp_flag if dp_flag else ddp_flag
|
||||
|
|
|
@ -58,6 +58,25 @@ class TrainerIO(object):
|
|||
# do the actual save
|
||||
torch.save(checkpoint, filepath)
|
||||
|
||||
def restore(self, checkpoint_path, on_gpu):
    """Restore trainer and model state from a saved checkpoint file.

    :param checkpoint_path: path to a checkpoint produced by this trainer
    :param on_gpu: when False, tensors are remapped onto CPU storage
    """
    # torch.load's map_location defaults to None, so passing None when
    # on_gpu is set is identical to omitting the argument; otherwise remap
    # every storage onto CPU.
    map_location = None if on_gpu else (lambda storage, loc: storage)
    checkpoint = torch.load(checkpoint_path, map_location=map_location)

    # load training state (affects trainer only)
    self.restore_training_state(checkpoint)

    # load model state
    model = self.__get_model()

    # load the state_dict on the model automatically
    model.load_state_dict(checkpoint['state_dict'])

    # call model hook
    model.on_hpc_load(checkpoint)
|
||||
|
||||
def dump_checkpoint(self):
|
||||
|
||||
checkpoint = {
|
||||
|
@ -198,15 +217,15 @@ class TrainerIO(object):
|
|||
# call model hook
|
||||
model.on_hpc_load(checkpoint)
|
||||
|
||||
def max_ckpt_in_folder(self, path):
|
||||
def max_ckpt_in_folder(self, path, name_key='ckpt_'):
|
||||
files = os.listdir(path)
|
||||
files = [x for x in files if 'ckpt_' in x]
|
||||
files = [x for x in files if name_key in x]
|
||||
if len(files) == 0:
|
||||
return 0
|
||||
|
||||
ckpt_vs = []
|
||||
for name in files:
|
||||
name = name.split('ckpt_')[-1]
|
||||
name = name.split(name_key)[-1]
|
||||
name = re.sub('[^0-9]', '', name)
|
||||
ckpt_vs.append(int(name))
|
||||
|
||||
|
|
Loading…
Reference in New Issue