Replace GPU device idx with current process index (#1541)

This commit is contained in:
Kevin Chen 2020-04-21 14:29:15 -04:00 committed by GitHub
parent 29c7d2f195
commit bafdeca42f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 1 additions and 1 deletions

View File

@ -327,7 +327,7 @@ class TrainerDDPMixin(ABC):
# MODEL
# copy model to each gpu
if self.on_gpu:
self.root_gpu = self.data_parallel_device_ids[process_idx]
self.root_gpu = process_idx
torch.cuda.set_device(self.root_gpu)
model.cuda(self.root_gpu)