* fixes #154

* Update trainer.py

* Update trainer.py
This commit is contained in:
William Falcon 2019-08-20 16:59:26 -04:00 committed by GitHub
parent 7119ec1693
commit 55a804b7cf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 5 additions and 2 deletions

View File

@ -921,8 +921,8 @@ If you want each process to load the full dataset, ignore this warning.
if isinstance(batch, torch.Tensor):
return batch.cuda(gpu_id)
# when list
elif isinstance(batch, list):
# when list/tuple
elif isinstance(batch, list) or isinstance(batch, tuple):
for i, x in enumerate(batch):
batch[i] = self.transfer_batch_to_gpu(x, gpu_id)
return batch
@ -934,6 +934,9 @@ If you want each process to load the full dataset, ignore this warning.
return batch
# nothing matches, return the value as is without transform
return batch
def __tng_forward(self, data_batch, batch_nb, opt_idx):
"""
Handle forward for each training case (distributed, single gpu, etc...)