updated args
This commit is contained in:
parent
117515db48
commit
ac88e3f832
|
@ -79,7 +79,6 @@ class ExampleModel(RootModule):
|
|||
})
|
||||
return output
|
||||
|
||||
|
||||
def validation_step(self, data_batch, batch_i):
|
||||
"""
|
||||
Called inside the validation loop
|
||||
|
|
|
@ -3,6 +3,7 @@ from torch.nn import DataParallel
|
|||
import threading
|
||||
import torch
|
||||
from torch.cuda._utils import _get_device_index
|
||||
import pdb
|
||||
|
||||
|
||||
def get_a_var(obj):
|
||||
|
@ -82,6 +83,7 @@ def parallel_apply(modules, inputs, kwargs_tup=None, devices=None):
|
|||
with lock:
|
||||
results[i] = e
|
||||
|
||||
pdb.set_trace()
|
||||
if len(modules) > 1:
|
||||
threads = [threading.Thread(target=_worker,
|
||||
args=(i, module, input, kwargs, device))
|
||||
|
|
Loading…
Reference in New Issue