parent
7119ec1693
commit
55a804b7cf
|
@ -921,8 +921,8 @@ If you want each process to load the full dataset, ignore this warning.
|
|||
if isinstance(batch, torch.Tensor):
|
||||
return batch.cuda(gpu_id)
|
||||
|
||||
# when list
|
||||
elif isinstance(batch, list):
|
||||
# when list/tuple
|
||||
elif isinstance(batch, list) or isinstance(batch, tuple):
|
||||
for i, x in enumerate(batch):
|
||||
batch[i] = self.transfer_batch_to_gpu(x, gpu_id)
|
||||
return batch
|
||||
|
@ -934,6 +934,9 @@ If you want each process to load the full dataset, ignore this warning.
|
|||
|
||||
return batch
|
||||
|
||||
# nothing matches, return the value as is without transform
|
||||
return batch
|
||||
|
||||
def __tng_forward(self, data_batch, batch_nb, opt_idx):
|
||||
"""
|
||||
Handle forward for each training case (distributed, single gpu, etc...)
|
||||
|
|
Loading…
Reference in New Issue