Fix torchtext data to gpu (#4785)

Co-authored-by: chaton <thomas@grid.ai>
This commit is contained in:
Jungwhan 2020-11-24 14:27:14 +09:00 committed by GitHub
parent 7d96fd1168
commit 471ca375ba
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 3 additions and 1 deletions

View File

@ -113,7 +113,9 @@ def move_data_to_device(batch: Any, device: torch.device):
# Shallow copy because each Batch has a reference to Dataset which contains all examples
device_data = copy(data)
for field in data.fields:
for field, field_value in data.dataset.fields.items():
if field_value is None:
continue
device_field = move_data_to_device(getattr(data, field), device)
setattr(device_data, field, device_field)
return device_data