diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py index 52dcb4be26..479456ccea 100644 --- a/pytorch_lightning/models/trainer.py +++ b/pytorch_lightning/models/trainer.py @@ -320,7 +320,7 @@ class Trainer(TrainerIO): def single_gpu_train(self, model): torch.cuda.set_device(0) - model = model.cuda(0) + model.cuda(0) # CHOOSE OPTIMIZER # filter out the weights that were done on gpu so we can load on good old cpus