removed self.model refs

This commit is contained in:
William Falcon 2019-06-26 18:12:33 -04:00
parent df4ac681ed
commit bf0f5a5cbb
1 changed files with 5 additions and 3 deletions

View File

@ -2,7 +2,7 @@ import torch
import os
import re
import pdb
from pytorch_lightning.pt_overrides.override_data_parallel import LightningDataParallel
class ModelIO(object):
@ -49,7 +49,8 @@ class TrainerIO(object):
checkpoint['optimizer_states'] = optimizer_states
# request what to save from the model
checkpoint_dict = self.model.get_save_dict()
model = self.model.module if type(self.model) is LightningDataParallel else self.model
checkpoint_dict = model.get_save_dict()
# merge trainer and model saving items
checkpoint.update(checkpoint_dict)
@ -130,7 +131,8 @@ class TrainerIO(object):
self.restore_training_state(checkpoint)
# load model state
self.model.load_model_specific(checkpoint)
model = self.model.module if type(self.model) is LightningDataParallel else self.model
model.load_model_specific(checkpoint)
def max_ckpt_in_folder(self, path):
files = os.listdir(path)