updated args

This commit is contained in:
William Falcon 2019-06-25 20:03:27 -04:00
parent 117515db48
commit ac88e3f832
2 changed files with 2 additions and 1 deletions

View File

@ -79,7 +79,6 @@ class ExampleModel(RootModule):
})
return output
def validation_step(self, data_batch, batch_i):
"""
Called inside the validation loop

View File

@ -3,6 +3,7 @@ from torch.nn import DataParallel
import threading
import torch
from torch.cuda._utils import _get_device_index
import pdb
def get_a_var(obj):
@ -82,6 +83,7 @@ def parallel_apply(modules, inputs, kwargs_tup=None, devices=None):
with lock:
results[i] = e
pdb.set_trace()
if len(modules) > 1:
threads = [threading.Thread(target=_worker,
args=(i, module, input, kwargs, device))