updated args
This commit is contained in:
parent
117515db48
commit
ac88e3f832
|
@ -79,7 +79,6 @@ class ExampleModel(RootModule):
|
||||||
})
|
})
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
def validation_step(self, data_batch, batch_i):
|
def validation_step(self, data_batch, batch_i):
|
||||||
"""
|
"""
|
||||||
Called inside the validation loop
|
Called inside the validation loop
|
||||||
|
|
|
@ -3,6 +3,7 @@ from torch.nn import DataParallel
|
||||||
import threading
|
import threading
|
||||||
import torch
|
import torch
|
||||||
from torch.cuda._utils import _get_device_index
|
from torch.cuda._utils import _get_device_index
|
||||||
|
import pdb
|
||||||
|
|
||||||
|
|
||||||
def get_a_var(obj):
|
def get_a_var(obj):
|
||||||
|
@ -82,6 +83,7 @@ def parallel_apply(modules, inputs, kwargs_tup=None, devices=None):
|
||||||
with lock:
|
with lock:
|
||||||
results[i] = e
|
results[i] = e
|
||||||
|
|
||||||
|
pdb.set_trace()
|
||||||
if len(modules) > 1:
|
if len(modules) > 1:
|
||||||
threads = [threading.Thread(target=_worker,
|
threads = [threading.Thread(target=_worker,
|
||||||
args=(i, module, input, kwargs, device))
|
args=(i, module, input, kwargs, device))
|
||||||
|
|
Loading…
Reference in New Issue