updated args

This commit is contained in:
William Falcon 2019-06-25 20:03:27 -04:00
parent 117515db48
commit ac88e3f832
2 changed files with 2 additions and 1 deletions

View File

@ -79,7 +79,6 @@ class ExampleModel(RootModule):
}) })
return output return output
def validation_step(self, data_batch, batch_i): def validation_step(self, data_batch, batch_i):
""" """
Called inside the validation loop Called inside the validation loop

View File

@ -3,6 +3,7 @@ from torch.nn import DataParallel
import threading import threading
import torch import torch
from torch.cuda._utils import _get_device_index from torch.cuda._utils import _get_device_index
import pdb
def get_a_var(obj): def get_a_var(obj):
@ -82,6 +83,7 @@ def parallel_apply(modules, inputs, kwargs_tup=None, devices=None):
with lock: with lock:
results[i] = e results[i] = e
pdb.set_trace()
if len(modules) > 1: if len(modules) > 1:
threads = [threading.Thread(target=_worker, threads = [threading.Thread(target=_worker,
args=(i, module, input, kwargs, device)) args=(i, module, input, kwargs, device))