updated args
This commit is contained in:
parent
0fd4d5e7a1
commit
4d42b1ed5f
|
@ -246,10 +246,7 @@ class Trainer(TrainerIO):
|
|||
|
||||
# put on gpu if needed
|
||||
if self.on_gpu:
|
||||
if self.data_parallel:
|
||||
model = DataParallel(model, device_ids=self.data_parallel_device_ids)
|
||||
else:
|
||||
model = model.cuda()
|
||||
model = DataParallel(model, device_ids=self.data_parallel_device_ids)
|
||||
|
||||
# run tiny validation to make sure program won't crash during val
|
||||
_ = self.validate(model, self.val_dataloader, max_batches=self.nb_sanity_val_steps)
|
||||
|
|
Loading…
Reference in New Issue