updated args

This commit is contained in:
William Falcon 2019-06-25 20:24:03 -04:00
parent 88606c581f
commit 440f47b864
1 changed files with 4 additions and 1 deletions

View File

@ -196,7 +196,10 @@ class Trainer(TrainerIO):
# -----------------
# RUN VALIDATION STEP
# -----------------
output = model(data_batch, batch_i)
if self.data_parallel:
output = model(data_batch, batch_i)
else:
output = model.validation_step(data_batch, batch_i)
# when DP, we need to aggregate the scalars we received as outputs
# use mean as the reduce function