Fix torchtext data to gpu (#4785)
Co-authored-by: chaton <thomas@grid.ai>
This commit is contained in:
parent
7d96fd1168
commit
471ca375ba
|
@ -113,7 +113,9 @@ def move_data_to_device(batch: Any, device: torch.device):
|
|||
|
||||
# Shallow copy because each Batch has a reference to Dataset which contains all examples
|
||||
device_data = copy(data)
|
||||
for field in data.fields:
|
||||
for field, field_value in data.dataset.fields.items():
|
||||
if field_value is None:
|
||||
continue
|
||||
device_field = move_data_to_device(getattr(data, field), device)
|
||||
setattr(device_data, field, device_field)
|
||||
return device_data
|
||||
|
|
Loading…
Reference in New Issue