From ac88e3f8322ae07605fedf47fb75c8d5ed1cedbb Mon Sep 17 00:00:00 2001 From: William Falcon Date: Tue, 25 Jun 2019 20:03:27 -0400 Subject: [PATCH] updated args --- docs/source/examples/example_model.py | 1 - pytorch_lightning/pt_overrides/override_data_parallel.py | 2 ++ 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/source/examples/example_model.py b/docs/source/examples/example_model.py index 3209794d80..e7caed4bef 100644 --- a/docs/source/examples/example_model.py +++ b/docs/source/examples/example_model.py @@ -79,7 +79,6 @@ class ExampleModel(RootModule): }) return output - def validation_step(self, data_batch, batch_i): """ Called inside the validation loop diff --git a/pytorch_lightning/pt_overrides/override_data_parallel.py b/pytorch_lightning/pt_overrides/override_data_parallel.py index 55540f9a20..19470106d7 100644 --- a/pytorch_lightning/pt_overrides/override_data_parallel.py +++ b/pytorch_lightning/pt_overrides/override_data_parallel.py @@ -3,6 +3,7 @@ from torch.nn import DataParallel import threading import torch from torch.cuda._utils import _get_device_index +import pdb def get_a_var(obj): @@ -82,6 +83,7 @@ def parallel_apply(modules, inputs, kwargs_tup=None, devices=None): with lock: results[i] = e + pdb.set_trace() if len(modules) > 1: threads = [threading.Thread(target=_worker, args=(i, module, input, kwargs, device))