removed self.model refs

This commit is contained in:
William Falcon 2019-06-26 18:08:46 -04:00
parent c1cbb1039a
commit df4ac681ed
1 changed files with 3 additions and 0 deletions

View File

@ -1,5 +1,6 @@
import numpy as np
import os, shutil
from pytorch_lightning.pt_overrides.override_data_parallel import LightningDataParallel
class Callback(object):
@ -33,6 +34,8 @@ class Callback(object):
self.params = params
def set_model(self, model):
if type(model) is LightningDataParallel:
model = model.module
self.model = model
def on_epoch_begin(self, epoch, logs=None):