diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py index 241b240160..434e50609b 100644 --- a/pytorch_lightning/models/trainer.py +++ b/pytorch_lightning/models/trainer.py @@ -268,7 +268,6 @@ class Trainer(TrainerIO): # put on gpu if needed if self.on_gpu: - model.cuda() model = LightningDataParallel(model, device_ids=self.data_parallel_device_ids) # run tiny validation to make sure program won't crash during val