From 4d42b1ed5f77224348835e7d22c8eec7d837d5f5 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Tue, 25 Jun 2019 19:00:38 -0400 Subject: [PATCH] updated args --- pytorch_lightning/models/trainer.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py index b2636d3c21..b208efc080 100644 --- a/pytorch_lightning/models/trainer.py +++ b/pytorch_lightning/models/trainer.py @@ -246,10 +246,7 @@ class Trainer(TrainerIO): # put on gpu if needed if self.on_gpu: - if self.data_parallel: - model = DataParallel(model, device_ids=self.data_parallel_device_ids) - else: - model = model.cuda() + model = DataParallel(model, device_ids=self.data_parallel_device_ids) # run tiny validation to make sure program won't crash during val _ = self.validate(model, self.val_dataloader, max_batches=self.nb_sanity_val_steps)