update (#7056)
This commit is contained in:
parent
8bcd169767
commit
7b0b0d2844
|
@ -372,12 +372,7 @@ class Accelerator:
|
|||
|
||||
def to_device(self, batch: Any) -> Any:
|
||||
"""Pushes the batch to the root device"""
|
||||
# Todo (tchaton) Better fix
|
||||
is_dict = isinstance(batch, dict)
|
||||
if is_dict:
|
||||
batch = [batch]
|
||||
batch = self.batch_to_device(batch, self.root_device)
|
||||
return batch[0] if is_dict else batch
|
||||
return self.batch_to_device(batch, self.root_device)
|
||||
|
||||
@property
|
||||
def amp_backend(self) -> Optional[LightningEnum]:
|
||||
|
|
Loading…
Reference in New Issue