diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py index 9dfcc57091..4b6c151f5f 100644 --- a/pytorch_lightning/models/trainer.py +++ b/pytorch_lightning/models/trainer.py @@ -460,6 +460,8 @@ class Trainer(TrainerIO): # run through amp wrapper if self.use_amp: + model.cuda(self.data_parallel_device_ids[0]) + # An example model, optimizers = amp.initialize( model, self.optimizers, opt_level=self.amp_level,