From 0795e4d51b3734b7fbae840af2550800637b70bf Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Tue, 25 Jun 2019 19:46:49 -0400
Subject: [PATCH] updated args

---
 pytorch_lightning/pt_overrides/override_data_parallel.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/pytorch_lightning/pt_overrides/override_data_parallel.py b/pytorch_lightning/pt_overrides/override_data_parallel.py
index bc2280f3d2..8d6b89fe03 100644
--- a/pytorch_lightning/pt_overrides/override_data_parallel.py
+++ b/pytorch_lightning/pt_overrides/override_data_parallel.py
@@ -9,7 +9,6 @@ class LightningDataParallel(DataParallel):
     """
 
     def forward(self, *inputs, **kwargs):
-        pdb.set_trace()
         if not self.device_ids:
             # -------------
             # MAIN CHANGE
@@ -27,7 +26,10 @@ class LightningDataParallel(DataParallel):
         inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
 
         if len(self.device_ids) == 1:
-            return self.module(*inputs[0], **kwargs[0])
+            if self.module.training:
+                return self.module.training_step(*inputs[0], **kwargs[0])
+            else:
+                return self.module.validation_step(*inputs[0], **kwargs[0])
         replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
         outputs = self.parallel_apply(replicas, inputs, kwargs)
         return self.gather(outputs, self.output_device)
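
Note (not part of the patch): a minimal sketch of how the changed single-device branch behaves. The patch drops a leftover pdb.set_trace() and, when exactly one device is used, routes the wrapped module through training_step or validation_step depending on the module's training flag instead of calling forward. ToyModule, its batch shapes, and the dispatch calls below are hypothetical stand-ins for a LightningModule; the import path is the file touched above, and the sketch assumes that module is importable in the installed version and that a CUDA device is available.

    import torch
    import torch.nn as nn
    # import path taken from the file modified by this patch
    from pytorch_lightning.pt_overrides.override_data_parallel import LightningDataParallel

    class ToyModule(nn.Module):
        # hypothetical stand-in: any module exposing training_step / validation_step
        def __init__(self):
            super().__init__()
            self.layer = nn.Linear(4, 2)

        def training_step(self, batch):
            # reached when self.training is True (after .train())
            return {"loss": self.layer(batch).sum()}

        def validation_step(self, batch):
            # reached when self.training is False (after .eval())
            return {"val_out": self.layer(batch)}

    model = ToyModule().cuda()
    dp_model = LightningDataParallel(model, device_ids=[0])

    batch = torch.randn(8, 4).cuda()

    dp_model.train()
    train_out = dp_model(batch)   # single device -> dispatched to model.training_step(batch)

    dp_model.eval()
    val_out = dp_model(batch)     # single device -> dispatched to model.validation_step(batch)

With more than one device id, the unchanged multi-GPU path still replicates the module, runs parallel_apply, and gathers the outputs on output_device.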