removed self.model refs
This commit is contained in:
parent
c1cbb1039a
commit
df4ac681ed
|
@ -1,5 +1,6 @@
|
|||
import numpy as np
|
||||
import os, shutil
|
||||
from pytorch_lightning.pt_overrides.override_data_parallel import LightningDataParallel
|
||||
|
||||
|
||||
class Callback(object):
|
||||
|
@ -33,6 +34,8 @@ class Callback(object):
|
|||
self.params = params
|
||||
|
||||
def set_model(self, model):
|
||||
if type(model) is LightningDataParallel:
|
||||
model = model.module
|
||||
self.model = model
|
||||
|
||||
def on_epoch_begin(self, epoch, logs=None):
|
||||
|
|
Loading…
Reference in New Issue