diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 3045443699..21cfe08b68 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -372,12 +372,7 @@ class Accelerator: def to_device(self, batch: Any) -> Any: """Pushes the batch to the root device""" - # Todo (tchaton) Better fix - is_dict = isinstance(batch, dict) - if is_dict: - batch = [batch] - batch = self.batch_to_device(batch, self.root_device) - return batch[0] if is_dict else batch + return self.batch_to_device(batch, self.root_device) @property def amp_backend(self) -> Optional[LightningEnum]: