removed self.model refs
This commit is contained in:
parent
df4ac681ed
commit
bf0f5a5cbb
|
@ -2,7 +2,7 @@ import torch
|
|||
import os
|
||||
import re
|
||||
import pdb
|
||||
|
||||
from pytorch_lightning.pt_overrides.override_data_parallel import LightningDataParallel
|
||||
|
||||
class ModelIO(object):
|
||||
|
||||
|
@ -49,7 +49,8 @@ class TrainerIO(object):
|
|||
checkpoint['optimizer_states'] = optimizer_states
|
||||
|
||||
# request what to save from the model
|
||||
checkpoint_dict = self.model.get_save_dict()
|
||||
model = self.model.module if type(self.model) is LightningDataParallel else self.model
|
||||
checkpoint_dict = model.get_save_dict()
|
||||
|
||||
# merge trainer and model saving items
|
||||
checkpoint.update(checkpoint_dict)
|
||||
|
@ -130,7 +131,8 @@ class TrainerIO(object):
|
|||
self.restore_training_state(checkpoint)
|
||||
|
||||
# load model state
|
||||
self.model.load_model_specific(checkpoint)
|
||||
model = self.model.module if type(self.model) is LightningDataParallel else self.model
|
||||
model.load_model_specific(checkpoint)
|
||||
|
||||
def max_ckpt_in_folder(self, path):
|
||||
files = os.listdir(path)
|
||||
|
|
Loading…
Reference in New Issue