removed self.model refs
This commit is contained in:
parent
c1cbb1039a
commit
df4ac681ed
|
@ -1,5 +1,6 @@
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import os, shutil
|
import os, shutil
|
||||||
|
from pytorch_lightning.pt_overrides.override_data_parallel import LightningDataParallel
|
||||||
|
|
||||||
|
|
||||||
class Callback(object):
|
class Callback(object):
|
||||||
|
@ -33,6 +34,8 @@ class Callback(object):
|
||||||
self.params = params
|
self.params = params
|
||||||
|
|
||||||
def set_model(self, model):
|
def set_model(self, model):
|
||||||
|
if type(model) is LightningDataParallel:
|
||||||
|
model = model.module
|
||||||
self.model = model
|
self.model = model
|
||||||
|
|
||||||
def on_epoch_begin(self, epoch, logs=None):
|
def on_epoch_begin(self, epoch, logs=None):
|
||||||
|
|
Loading…
Reference in New Issue