From 53f1f18442496962ee016552643955e6f8fbfdd4 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Wed, 24 Jul 2019 13:50:02 -0400 Subject: [PATCH] refactored model tests --- pytorch_lightning/models/trainer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py index 9dfcc57091..4b6c151f5f 100644 --- a/pytorch_lightning/models/trainer.py +++ b/pytorch_lightning/models/trainer.py @@ -460,6 +460,8 @@ class Trainer(TrainerIO): # run through amp wrapper if self.use_amp: + model.cuda(self.data_parallel_device_ids[0]) + # An example model, optimizers = amp.initialize( model, self.optimizers, opt_level=self.amp_level,