diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 48bb359713..7c8d0143b8 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -151,7 +151,7 @@ When this flag is enabled each batch is split into sequences of size truncated_b
 
 
 """
-
+import copy
 import inspect
 from abc import ABC, abstractmethod
 import warnings
@@ -586,7 +586,7 @@ class TrainerTrainLoopMixin(ABC):
             gpu_id = 0
             if isinstance(self.data_parallel_device_ids, list):
                 gpu_id = self.data_parallel_device_ids[0]
-            batch = self.transfer_batch_to_gpu(batch.copy(), gpu_id)
+            batch = self.transfer_batch_to_gpu(copy.copy(batch), gpu_id)
             args[0] = batch
             output = self.model.training_step(*args)
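
The motivation for the change: `batch.copy()` assumes the batch object exposes a `.copy()` method, which holds for dicts and lists but not for tuples or arbitrary user-defined batch containers, whereas `copy.copy()` performs a shallow copy of almost any Python object. Below is a minimal sketch (not part of the patch, using made-up example batches) illustrating the difference:

```python
import copy

dict_batch = {"x": [1, 2, 3], "y": [0, 1, 0]}
tuple_batch = ([1, 2, 3], [0, 1, 0])

# dicts expose a .copy() method, so the old code worked for them ...
assert dict_batch.copy() == dict_batch

# ... but tuples (a common batch container) do not:
try:
    tuple_batch.copy()
except AttributeError as err:
    print(err)  # 'tuple' object has no attribute 'copy'

# copy.copy() shallow-copies (almost) any object,
# so it handles both container types uniformly:
assert copy.copy(dict_batch) == dict_batch
assert copy.copy(tuple_batch) == tuple_batch
```

Note that both spellings are shallow copies: the container is duplicated, but the tensors inside it are not, so the per-batch overhead before the GPU transfer stays negligible.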