updated args

This commit is contained in:
William Falcon 2019-06-25 19:54:28 -04:00
parent c941649532
commit 7814b2d449
1 changed files with 4 additions and 2 deletions

View File

@ -1,6 +1,4 @@
from itertools import chain
from torch.nn import DataParallel
import pdb
import threading
import torch
@ -70,10 +68,14 @@ def parallel_apply(modules, inputs, kwargs_tup=None, devices=None):
if not isinstance(input, (list, tuple)):
input = (input,)
# ---------------
# CHANGE
if module.training:
return module.training_step(*input, **kwargs)
else:
return module.validation_step(*input, **kwargs)
# ---------------
with lock:
results[i] = output
except Exception as e: