Shubhamagarwal92 master (#1349)

* SA: for #958: set torch cuda device when finding root

* SA: for #958: removing root gpu hack in trainer/evaluation_loop

* SA: setting torch cuda device

* fix comment line that was too long

* check if root gpu exists or is available

* Incorporating suggestions on #1094

* since root gpu returns None instead of -1 for CPU

* undo changes

* fixed DP memory issue

Co-authored-by: Shubham Agarwal <shubhamagarwal92@gmail.com>
William Falcon 2020-04-03 17:56:19 -04:00 committed by GitHub
parent f6a86e8551
commit 16f4cc9ff0
2 changed files with 4 additions and 0 deletions


@@ -526,6 +526,9 @@ class TrainerDPMixin(ABC):
         if isinstance(device_ids, int):
             device_ids = list(range(device_ids))
 
+        # set dp device
+        torch.cuda.set_device(self.root_gpu)
+
         model = LightningDataParallel(model, device_ids=device_ids)
 
         self.run_pretrain_routine(model)
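
The added call matters because DataParallel scatters inputs from, and gathers outputs onto, the first device in device_ids; if the current CUDA device does not match that root GPU, a CUDA context and extra memory can end up on GPU 0. A minimal sketch of the pattern, using plain torch.nn.DataParallel instead of Lightning's LightningDataParallel (wrap_in_dp and its arguments are hypothetical, not code from this commit):

    import torch
    from torch import nn

    def wrap_in_dp(model: nn.Module, device_ids, root_gpu):
        # Expand an int device count into an explicit list of device ids.
        if isinstance(device_ids, int):
            device_ids = list(range(device_ids))

        # Make the root GPU the current CUDA device so allocations made
        # during the forward pass land on the device that DataParallel
        # gathers outputs onto, instead of defaulting to GPU 0.
        if root_gpu is not None and torch.cuda.is_available():
            torch.cuda.set_device(root_gpu)
            model = model.cuda(root_gpu)

        return nn.DataParallel(model, device_ids=device_ids)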


@@ -389,6 +389,7 @@ class Trainer(
         self.gpus = gpus
         self.data_parallel_device_ids = parse_gpu_ids(self.gpus)
         self.root_gpu = determine_root_gpu_device(self.data_parallel_device_ids)
+        self.root_device = torch.device("cpu")
 
         # tpu state flags
         self.use_tpu = False
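
The new attribute only establishes a CPU default here; as the commit message notes, determine_root_gpu_device returns None (not -1) when training on CPU, so any later device resolution has to check for None before touching CUDA. A small sketch of that logic under those assumptions (resolve_root_device is a hypothetical helper for illustration, not part of the diff):

    import torch

    def resolve_root_device(root_gpu):
        # root_gpu is None when no GPUs were parsed, so keep the CPU default;
        # otherwise build the CUDA device for the root GPU index.
        if root_gpu is None:
            return torch.device("cpu")
        return torch.device(f"cuda:{root_gpu}")

    # Example: no GPUs parsed -> root_gpu is None -> cpu device
    assert resolve_root_device(None) == torch.device("cpu")
    assert resolve_root_device(0) == torch.device("cuda:0")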