diff --git a/pytorch_lightning/pt_overrides/override_data_parallel.py b/pytorch_lightning/pt_overrides/override_data_parallel.py index 5e0150ee5b..bc2280f3d2 100644 --- a/pytorch_lightning/pt_overrides/override_data_parallel.py +++ b/pytorch_lightning/pt_overrides/override_data_parallel.py @@ -1,5 +1,6 @@ from itertools import chain from torch.nn import DataParallel +import pdb class LightningDataParallel(DataParallel): @@ -8,6 +9,7 @@ class LightningDataParallel(DataParallel): """ def forward(self, *inputs, **kwargs): + pdb.set_trace() if not self.device_ids: # ------------- # MAIN CHANGE