clean up dead code
This commit is contained in:
parent
3ed02e4ed6
commit
522af58504
|
@ -1,6 +1,6 @@
|
|||
import numpy as np
|
||||
import os, shutil
|
||||
from pytorch_lightning.pt_overrides.override_data_parallel import LightningDataParallel
|
||||
from pytorch_lightning.pt_overrides.override_data_parallel import LightningDistributedDataParallel
|
||||
|
||||
|
||||
class Callback(object):
|
||||
|
@ -34,7 +34,7 @@ class Callback(object):
|
|||
self.params = params
|
||||
|
||||
def set_model(self, model):
|
||||
if type(model) is LightningDataParallel:
|
||||
if type(model) is LightningDistributedDataParallel:
|
||||
model = model.module
|
||||
self.model = model
|
||||
|
||||
|
|
Loading…
Reference in New Issue