replace obj.copy() with copy.copy(obj) (#701)

This commit is contained in:
Z ZH 2020-01-17 22:10:05 +09:00 committed by William Falcon
parent bc67689068
commit dac59bb8d3
1 changed files with 2 additions and 2 deletions

View File

@ -151,7 +151,7 @@ When this flag is enabled each batch is split into sequences of size truncated_b
"""
import copy
import inspect
from abc import ABC, abstractmethod
import warnings
@ -586,7 +586,7 @@ class TrainerTrainLoopMixin(ABC):
gpu_id = 0
if isinstance(self.data_parallel_device_ids, list):
gpu_id = self.data_parallel_device_ids[0]
batch = self.transfer_batch_to_gpu(batch.copy(), gpu_id)
batch = self.transfer_batch_to_gpu(copy.copy(batch), gpu_id)
args[0] = batch
output = self.model.training_step(*args)