diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py
index dc01f59cd1..9526100721 100644
--- a/pytorch_lightning/models/trainer.py
+++ b/pytorch_lightning/models/trainer.py
@@ -921,8 +921,8 @@ If you want each process to load the full dataset, ignore this warning.
         if isinstance(batch, torch.Tensor):
             return batch.cuda(gpu_id)
 
-        # when list
-        elif isinstance(batch, list):
+        # when list/tuple
+        elif isinstance(batch, list) or isinstance(batch, tuple):
             for i, x in enumerate(batch):
                 batch[i] = self.transfer_batch_to_gpu(x, gpu_id)
             return batch
@@ -934,6 +934,9 @@ If you want each process to load the full dataset, ignore this warning.
 
             return batch
 
+        # nothing matches, return the value as is without transform
+        return batch
+
     def __tng_forward(self, data_batch, batch_nb, opt_idx):
         """
         Handle forward for each training case (distributed, single gpu, etc...)
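For context, here is a minimal standalone sketch of the transfer logic after this patch. The name `transfer_batch_to_gpu` and the overall dispatch come from the diff; writing it as a module-level function rather than a `Trainer` method is an assumption made so the sketch is self-contained. One caveat: the in-place `batch[i] = ...` assignment shown in the diff would raise `TypeError` when `batch` is a tuple, since tuples are immutable, so the sketch rebuilds tuples instead of mutating them. Running it requires a CUDA-capable device.

```python
import torch


def transfer_batch_to_gpu(batch, gpu_id):
    # hypothetical standalone version of the Trainer method in the diff
    # base case: a tensor moves directly to the target GPU
    if isinstance(batch, torch.Tensor):
        return batch.cuda(gpu_id)

    # lists are mutated in place, element by element
    if isinstance(batch, list):
        for i, x in enumerate(batch):
            batch[i] = transfer_batch_to_gpu(x, gpu_id)
        return batch

    # tuples are immutable, so rebuild them rather than assigning in place
    # (the diff reuses the list branch for tuples; this is a safer variant)
    if isinstance(batch, tuple):
        return tuple(transfer_batch_to_gpu(x, gpu_id) for x in batch)

    # dicts: recurse over the values
    if isinstance(batch, dict):
        for k, v in batch.items():
            batch[k] = transfer_batch_to_gpu(v, gpu_id)
        return batch

    # nothing matches: return the value as-is, without transform
    return batch


# usage: non-tensor leaves (e.g. strings) pass through untouched,
# tensors nested anywhere in the structure land on the GPU
batch = {"x": torch.randn(2, 3), "meta": ("sample_1", torch.zeros(2))}
batch = transfer_batch_to_gpu(batch, gpu_id=0)
```

The final `return batch` added by this patch is what lets arbitrary leaf values (strings, ints, `None`) flow through unchanged instead of silently returning `None` when no branch matches.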