diff --git a/pytorch_lightning/pt_overrides/override_data_parallel.py b/pytorch_lightning/pt_overrides/override_data_parallel.py
index e2590b5248..2366b882c0 100644
--- a/pytorch_lightning/pt_overrides/override_data_parallel.py
+++ b/pytorch_lightning/pt_overrides/override_data_parallel.py
@@ -1,6 +1,7 @@
 from torch.nn import DataParallel
 from torch.nn.parallel import DistributedDataParallel
 import itertools
+from itertools import chain
 import threading

 import torch
@@ -42,6 +43,29 @@ class LightningDataParallel(DataParallel):
     Override the forward call in lightning so it goes to training and validation step respectively
     """

+    def forward(self, *inputs, **kwargs):
+        if not self.device_ids:
+            return self.module(*inputs, **kwargs)
+
+        for t in chain(self.module.parameters(), self.module.buffers()):
+            if t.device != self.src_device_obj:
+                raise RuntimeError("module must have its parameters and buffers "
+                                   "on device {} (device_ids[0]) but found one of "
+                                   "them on device: {}".format(self.src_device_obj, t.device))
+
+        inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
+        if len(self.device_ids) == 1:
+            # lightning
+            if self.module.training:
+                return self.module.training_step(*inputs[0], **kwargs[0])
+            else:
+                return self.module.validation_step(*inputs[0], **kwargs[0])
+
+        replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
+        outputs = self.parallel_apply(replicas, inputs, kwargs)
+        return self.gather(outputs, self.output_device)
+
+
     def parallel_apply(self, replicas, inputs, kwargs):
         print('LDP')
         return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
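
For context, here is a minimal usage sketch of how the overridden forward dispatches in the single-device case. It is not part of the patch: ToyModule, its layer, and the CUDA guard are illustrative assumptions, while LightningDataParallel, its import path, and the training_step/validation_step hooks are taken from the diff above.

# Minimal usage sketch (not part of the patch). ToyModule and the single-GPU
# setup are illustrative assumptions; the wrapper and hook names mirror the diff.
import torch
import torch.nn as nn

from pytorch_lightning.pt_overrides.override_data_parallel import LightningDataParallel


class ToyModule(nn.Module):
    """Stands in for a LightningModule: only the hooks used by the override matter here."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 2)

    def training_step(self, batch, batch_idx):
        # reached via LightningDataParallel.forward when the wrapper is in train mode
        return {'loss': self.layer(batch).sum()}

    def validation_step(self, batch, batch_idx):
        # reached via LightningDataParallel.forward when the wrapper is in eval mode
        return {'val_loss': self.layer(batch).sum()}


if torch.cuda.is_available():
    model = ToyModule().cuda(0)
    wrapped = LightningDataParallel(model, device_ids=[0])
    batch = torch.randn(8, 4, device='cuda:0')

    wrapped.train()
    train_out = wrapped(batch, 0)   # single device: calls module.training_step(batch, 0)

    wrapped.eval()
    val_out = wrapped(batch, 0)     # single device: calls module.validation_step(batch, 0)

In train mode the wrapper routes the call to training_step, in eval mode to validation_step; with more than one entry in device_ids it instead replicates the module, runs parallel_apply, and gathers the per-replica outputs on the output device.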