diff --git a/pytorch_lightning/pt_overrides/override_data_parallel.py b/pytorch_lightning/pt_overrides/override_data_parallel.py index 634c4ff9de..e2590b5248 100644 --- a/pytorch_lightning/pt_overrides/override_data_parallel.py +++ b/pytorch_lightning/pt_overrides/override_data_parallel.py @@ -43,6 +43,7 @@ class LightningDataParallel(DataParallel): """ def parallel_apply(self, replicas, inputs, kwargs): + print('LDP') return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)]) @@ -55,6 +56,7 @@ class LightningDistributedDataParallel(DistributedDataParallel): return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)]) def forward(self, *inputs, **kwargs): + print('LDDP') self._sync_params() if self.device_ids: inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)