Copy batch for local forward (#532)
This commit is contained in:
parent
55edf7c922
commit
48b797fdb0
|
@ -290,7 +290,7 @@ class TrainerTrainLoopMixin(object):
|
|||
gpu_id = 0
|
||||
if type(self.data_parallel_device_ids) is list:
|
||||
gpu_id = self.data_parallel_device_ids[0]
|
||||
batch = self.transfer_batch_to_gpu(batch, gpu_id)
|
||||
batch = self.transfer_batch_to_gpu(batch.copy(), gpu_id)
|
||||
args[0] = batch
|
||||
output = self.model.training_step(*args)
|
||||
|
||||
|
|
Loading…
Reference in New Issue