updated args
This commit is contained in:
parent
c941649532
commit
7814b2d449
|
@ -1,6 +1,4 @@
|
|||
from itertools import chain
|
||||
from torch.nn import DataParallel
|
||||
import pdb
|
||||
|
||||
import threading
|
||||
import torch
|
||||
|
@ -70,10 +68,14 @@ def parallel_apply(modules, inputs, kwargs_tup=None, devices=None):
|
|||
if not isinstance(input, (list, tuple)):
|
||||
input = (input,)
|
||||
|
||||
# ---------------
|
||||
# CHANGE
|
||||
if module.training:
|
||||
return module.training_step(*input, **kwargs)
|
||||
else:
|
||||
return module.validation_step(*input, **kwargs)
|
||||
# ---------------
|
||||
|
||||
with lock:
|
||||
results[i] = output
|
||||
except Exception as e:
|
||||
|
|
Loading…
Reference in New Issue