Shubhamagarwal92 master (#1349)

* SA: for #958: set torch cuda device when finding root
* SA: for #958: removing root gpu hack in trainer/evaluation_loop
* SA: setting torch cuda device
* comment line too long
* check if root gpu exists or available
* Incorporating suggestions on #1094
* since root gpu returns none instead of -1 for cpu
* undo changes
* fixed dp memory thing

Co-authored-by: Shubham Agarwal <shubhamagarwal92@gmail.com>
This commit is contained in:
parent
f6a86e8551
commit
16f4cc9ff0
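The commit message notes that the root-GPU helper returns None rather than -1 on CPU-only runs, so the device switch has to be guarded. A minimal sketch of that guard, assuming the helper names shown in the diff below (`set_root_device` itself is a hypothetical illustration, not Lightning code):

```python
import torch

def set_root_device(root_gpu):
    # determine_root_gpu_device() yields None (not -1) when no GPUs are
    # configured, so only switch the current CUDA device when a root GPU
    # actually exists and CUDA is available.
    if root_gpu is not None and torch.cuda.is_available():
        torch.cuda.set_device(root_gpu)
```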
@@ -526,6 +526,9 @@ class TrainerDPMixin(ABC):
         if isinstance(device_ids, int):
             device_ids = list(range(device_ids))
 
+        # set dp device
+        torch.cuda.set_device(self.root_gpu)
+
         model = LightningDataParallel(model, device_ids=device_ids)
 
         self.run_pretrain_routine(model)
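For context, a minimal plain-PyTorch sketch (not Lightning's code) of why the current CUDA device is switched to the root GPU before the model is wrapped for data parallelism; it assumes at least two visible GPUs and uses `nn.DataParallel`, which `LightningDataParallel` wraps:

```python
import torch
import torch.nn as nn

device_ids = [0, 1]          # assumed visible GPUs
root_gpu = device_ids[0]

# Make the root GPU the current CUDA device *before* building the DP
# wrapper, so tensors allocated without an explicit device land on it.
torch.cuda.set_device(root_gpu)

model = nn.Linear(32, 4).cuda(root_gpu)
model = nn.DataParallel(model, device_ids=device_ids)

# Outputs are gathered back onto the root GPU.
out = model(torch.randn(8, 32).cuda(root_gpu))
```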
@@ -389,6 +389,7 @@ class Trainer(
         self.gpus = gpus
         self.data_parallel_device_ids = parse_gpu_ids(self.gpus)
         self.root_gpu = determine_root_gpu_device(self.data_parallel_device_ids)
+        self.root_device = torch.device("cpu")
 
         # tpu state flags
         self.use_tpu = False
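The new `root_device` attribute only gets its CPU default in this hunk; a sketch of how a root device would typically be resolved once the GPU ids have been parsed (`resolve_root_device` is a hypothetical helper, not part of this diff):

```python
import torch

def resolve_root_device(root_gpu):
    # Start from CPU and only move to a CUDA device when a root GPU has
    # actually been determined from the parsed gpu ids.
    if root_gpu is None:
        return torch.device("cpu")
    return torch.device("cuda", root_gpu)

# resolve_root_device(None) -> device(type='cpu')
# resolve_root_device(0)    -> device(type='cuda', index=0)
```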