From dac59bb8d354f28b919d08aa369a7db6ab31bfa6 Mon Sep 17 00:00:00 2001 From: Z ZH Date: Fri, 17 Jan 2020 22:10:05 +0900 Subject: [PATCH] replace obj.copy() with copy.copy(obj) (#701) --- pytorch_lightning/trainer/training_loop.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 48bb359713..7c8d0143b8 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -151,7 +151,7 @@ When this flag is enabled each batch is split into sequences of size truncated_b """ - +import copy import inspect from abc import ABC, abstractmethod import warnings @@ -586,7 +586,7 @@ class TrainerTrainLoopMixin(ABC): gpu_id = 0 if isinstance(self.data_parallel_device_ids, list): gpu_id = self.data_parallel_device_ids[0] - batch = self.transfer_batch_to_gpu(batch.copy(), gpu_id) + batch = self.transfer_batch_to_gpu(copy.copy(batch), gpu_id) args[0] = batch output = self.model.training_step(*args)