updated args

This commit is contained in:
William Falcon 2019-06-25 19:00:38 -04:00
parent 0fd4d5e7a1
commit 4d42b1ed5f
1 changed files with 1 additions and 4 deletions

View File

@ -246,10 +246,7 @@ class Trainer(TrainerIO):
# put on gpu if needed
if self.on_gpu:
if self.data_parallel:
model = DataParallel(model, device_ids=self.data_parallel_device_ids)
else:
model = model.cuda()
model = DataParallel(model, device_ids=self.data_parallel_device_ids)
# run tiny validation to make sure program won't crash during val
_ = self.validate(model, self.val_dataloader, max_batches=self.nb_sanity_val_steps)