replace obj.copy() with copy.copy(obj) (#701)
This commit is contained in:
parent
bc67689068
commit
dac59bb8d3
|
@ -151,7 +151,7 @@ When this flag is enabled each batch is split into sequences of size truncated_b
|
|||
|
||||
|
||||
"""
|
||||
|
||||
import copy
|
||||
import inspect
|
||||
from abc import ABC, abstractmethod
|
||||
import warnings
|
||||
|
@ -586,7 +586,7 @@ class TrainerTrainLoopMixin(ABC):
|
|||
gpu_id = 0
|
||||
if isinstance(self.data_parallel_device_ids, list):
|
||||
gpu_id = self.data_parallel_device_ids[0]
|
||||
batch = self.transfer_batch_to_gpu(batch.copy(), gpu_id)
|
||||
batch = self.transfer_batch_to_gpu(copy.copy(batch), gpu_id)
|
||||
args[0] = batch
|
||||
output = self.model.training_step(*args)
|
||||
|
||||
|
|
Loading…
Reference in New Issue