From ecb68b52f85bf92eaf424c4d3936b21e9ccc0cfc Mon Sep 17 00:00:00 2001 From: William Falcon Date: Wed, 24 Jul 2019 13:56:49 -0400 Subject: [PATCH] refactored model tests --- pytorch_lightning/models/trainer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py index 4b6c151f5f..68ffe8c505 100644 --- a/pytorch_lightning/models/trainer.py +++ b/pytorch_lightning/models/trainer.py @@ -458,9 +458,10 @@ class Trainer(TrainerIO): # filter out the weights that were done on gpu so we can load on good old cpus self.optimizers = model.configure_optimizers() + model.cuda(self.data_parallel_device_ids[0]) + # run through amp wrapper if self.use_amp: - model.cuda(self.data_parallel_device_ids[0]) # An example model, optimizers = amp.initialize(