From 7b0b0d284494d08e3983321d0cc42fe9e5faeb41 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Fri, 16 Apr 2021 21:22:19 +0100 Subject: [PATCH] update (#7056) --- pytorch_lightning/accelerators/accelerator.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 3045443699..21cfe08b68 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -372,12 +372,7 @@ class Accelerator: def to_device(self, batch: Any) -> Any: """Pushes the batch to the root device""" - # Todo (tchaton) Better fix - is_dict = isinstance(batch, dict) - if is_dict: - batch = [batch] - batch = self.batch_to_device(batch, self.root_device) - return batch[0] if is_dict else batch + return self.batch_to_device(batch, self.root_device) @property def amp_backend(self) -> Optional[LightningEnum]: