diff --git a/.gitignore b/.gitignore
index 34bed65f66..e39440a309 100644
--- a/.gitignore
+++ b/.gitignore
@@ -116,3 +116,6 @@ ENV/
 
 # mypy
 .mypy_cache/
+
+# data
+mnist/
\ No newline at end of file
diff --git a/pytorch_lightning/pt_overrides/__init__.py b/pytorch_lightning/pt_overrides/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/pytorch_lightning/pt_overrides/override_data_parallel.py b/pytorch_lightning/pt_overrides/override_data_parallel.py
new file mode 100644
index 0000000000..5e0150ee5b
--- /dev/null
+++ b/pytorch_lightning/pt_overrides/override_data_parallel.py
@@ -0,0 +1,31 @@
+from itertools import chain
+from torch.nn import DataParallel
+
+
+class LightningDataParallel(DataParallel):
+    """
+    Override the forward call in lightning so it goes to training and validation step respectively
+    """
+
+    def forward(self, *inputs, **kwargs):
+        if not self.device_ids:
+            # -------------
+            # MAIN CHANGE
+            if self.module.training:
+                return self.module.training_step(*inputs, **kwargs)
+            else:
+                return self.module.validation_step(*inputs, **kwargs)
+            # -------------
+
+        for t in chain(self.module.parameters(), self.module.buffers()):
+            if t.device != self.src_device_obj:
+                raise RuntimeError("module must have its parameters and buffers "
+                                   "on device {} (device_ids[0]) but found one of "
+                                   "them on device: {}".format(self.src_device_obj, t.device))
+
+        inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
+        if len(self.device_ids) == 1:
+            return self.module(*inputs[0], **kwargs[0])
+        replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
+        outputs = self.parallel_apply(replicas, inputs, kwargs)
+        return self.gather(outputs, self.output_device)
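
For context, a minimal sketch of how the new override is meant to behave, assuming a CPU-only machine so that DataParallel leaves device_ids empty and forward() takes the MAIN CHANGE branch. ToyModule and its training_step/validation_step bodies are hypothetical and not part of this PR; only LightningDataParallel comes from the file added above.

import torch
from torch import nn

from pytorch_lightning.pt_overrides.override_data_parallel import LightningDataParallel


class ToyModule(nn.Module):
    """Hypothetical module exposing the training_step/validation_step hooks."""

    def __init__(self):
        super(ToyModule, self).__init__()
        self.layer = nn.Linear(4, 2)

    def training_step(self, batch):
        # illustrative loss computation, not the library's actual hook contract
        return {'loss': self.layer(batch).sum()}

    def validation_step(self, batch):
        return {'val_loss': self.layer(batch).sum()}


# On a CPU-only box DataParallel sets device_ids to [], so the override
# dispatches on self.module.training instead of calling self.module(...)
model = LightningDataParallel(ToyModule())
batch = torch.randn(8, 4)

model.train()        # module.training is True  -> routed to training_step
train_out = model(batch)

model.eval()         # module.training is False -> routed to validation_step
val_out = model(batch)

The point of the override is that the same DataParallel wrapper can be reused for both training and validation: flipping train()/eval() on the wrapped module decides which Lightning hook the forward call reaches.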