Replace GPU device idx with current process index (#1541)
This commit is contained in:
parent
29c7d2f195
commit
bafdeca42f
|
@ -327,7 +327,7 @@ class TrainerDDPMixin(ABC):
|
|||
# MODEL
|
||||
# copy model to each gpu
|
||||
if self.on_gpu:
|
||||
self.root_gpu = self.data_parallel_device_ids[process_idx]
|
||||
self.root_gpu = process_idx
|
||||
torch.cuda.set_device(self.root_gpu)
|
||||
model.cuda(self.root_gpu)
|
||||
|
||||
|
|
Loading…
Reference in New Issue