diff --git a/pytorch_lightning/utils/pt_callbacks.py b/pytorch_lightning/utils/pt_callbacks.py index 6c3888139e..46e6195dfd 100644 --- a/pytorch_lightning/utils/pt_callbacks.py +++ b/pytorch_lightning/utils/pt_callbacks.py @@ -1,5 +1,6 @@ import numpy as np import os, shutil +from pytorch_lightning.pt_overrides.override_data_parallel import LightningDataParallel class Callback(object): @@ -33,6 +34,8 @@ class Callback(object): self.params = params def set_model(self, model): + if type(model) is LightningDataParallel: + model = model.module self.model = model def on_epoch_begin(self, epoch, logs=None):