diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py index b2636d3c21..b208efc080 100644 --- a/pytorch_lightning/models/trainer.py +++ b/pytorch_lightning/models/trainer.py @@ -246,10 +246,7 @@ class Trainer(TrainerIO): # put on gpu if needed if self.on_gpu: - if self.data_parallel: - model = DataParallel(model, device_ids=self.data_parallel_device_ids) - else: - model = model.cuda() + model = DataParallel(model, device_ids=self.data_parallel_device_ids) # run tiny validation to make sure program won't crash during val _ = self.validate(model, self.val_dataloader, max_batches=self.nb_sanity_val_steps)