diff --git a/pytorch_lightning/pt_overrides/override_data_parallel.py b/pytorch_lightning/pt_overrides/override_data_parallel.py
index 98cc0ed871..55540f9a20 100644
--- a/pytorch_lightning/pt_overrides/override_data_parallel.py
+++ b/pytorch_lightning/pt_overrides/override_data_parallel.py
@@ -1,6 +1,4 @@
-from itertools import chain
 from torch.nn import DataParallel
-import pdb
 import threading
 import torch
 
@@ -70,10 +68,14 @@ def parallel_apply(modules, inputs, kwargs_tup=None, devices=None):
             if not isinstance(input, (list, tuple)):
                 input = (input,)
 
+            # ---------------
+            # CHANGE
             if module.training:
                 return module.training_step(*input, **kwargs)
             else:
                 return module.validation_step(*input, **kwargs)
+            # ---------------
+
         with lock:
             results[i] = output
     except Exception as e: