From 2f7a9ad40d2199ea7371bfc53e9a785c50952d14 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Wed, 7 Aug 2019 13:49:01 -0400
Subject: [PATCH] added single gpu train test

---
 pytorch_lightning/models/trainer.py | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py
index 55f06d53d9..91eb7cc89f 100644
--- a/pytorch_lightning/models/trainer.py
+++ b/pytorch_lightning/models/trainer.py
@@ -393,7 +393,9 @@ class Trainer(TrainerIO):
             output = reduce_distributed_output(output, len(self.data_parallel_device_ids))
 
         elif self.single_gpu:
-            output = model(data_batch.cuda(self.data_parallel_device_ids[0]), batch_i)
+            gpu_id = self.data_parallel_device_ids[0]
+            data_batch = [x.cuda(gpu_id) for x in data_batch if isinstance(x, torch.Tensor)]
+            output = model(data_batch, batch_i)
 
         else:
             output = model.validation_step(data_batch, batch_i)
@@ -474,7 +476,7 @@ If you're not using SLURM, ignore this message!
             self.__dp_train(model)
 
         elif self.single_gpu:
-            self.__single_gpu_train(model)\
+            self.__single_gpu_train(model)
 
         # ON CPU
         else:
@@ -846,7 +848,9 @@ We recommend you switch to ddp if you want to use amp
             output = self.model(data_batch, batch_nb)
             output = reduce_distributed_output(output, len(self.data_parallel_device_ids))
         elif self.single_gpu:
-            output = self.model(data_batch.cuda(self.data_parallel_device_ids[0]), batch_nb)
+            gpu_id = self.data_parallel_device_ids[0]
+            data_batch = [x.cuda(gpu_id) for x in data_batch if isinstance(x, torch.Tensor)]
+            output = self.model(data_batch, batch_nb)
 
         else:
             output = self.model.training_step(data_batch, batch_nb)
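
For readers skimming the hunks: the change moves each tensor in data_batch onto the
single target GPU before the forward call, rather than calling .cuda() on the batch
object itself (which fails once the batch is a list). Below is a minimal standalone
sketch of that transfer pattern. move_batch_to_device is a hypothetical helper for
illustration only, not part of pytorch_lightning; unlike the patch's list
comprehension, it passes non-tensor elements (e.g. integer labels) through unchanged
instead of silently dropping them.

    import torch

    def move_batch_to_device(batch, device):
        # Hypothetical helper, not part of pytorch_lightning.
        # Recursively moves tensors in a (possibly nested) batch to `device`;
        # non-tensor elements pass through unchanged.
        if isinstance(batch, torch.Tensor):
            return batch.to(device)
        if isinstance(batch, (list, tuple)):
            return type(batch)(move_batch_to_device(x, device) for x in batch)
        if isinstance(batch, dict):
            return {k: move_batch_to_device(v, device) for k, v in batch.items()}
        return batch

    # Usage, assuming a CUDA device is available:
    # batch = [torch.randn(32, 3, 28, 28), torch.randint(0, 10, (32,))]
    # batch = move_batch_to_device(batch, torch.device('cuda:0'))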