move `torch.cuda.set_device()` to enable collective calls earlier in setup (#8312)
parent 20df24d2a2
commit d73c32ab51
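
In short: with the NCCL backend, a rank's collective calls are issued against whatever CUDA device is current for that process, so `torch.cuda.set_device()` must run before the first collective. Moving the call from each plugin's setup into `GPUAccelerator.setup_environment()` makes collectives usable earlier in the trainer lifecycle. A minimal standalone sketch of the failure mode this guards against (illustrative only, not Lightning code; the `run` helper and the hard-coded master address/port are assumptions):

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def run(rank: int, world_size: int) -> None:
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    # NCCL binds each rank's collectives to the *current* CUDA device.
    # Skipping this call leaves every rank on cuda:0, and the barrier
    # below can hang or fail with a duplicate-GPU error.
    torch.cuda.set_device(rank)
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    dist.barrier()  # safe: each rank now owns a distinct device
    dist.destroy_process_group()


if __name__ == "__main__":
    n = torch.cuda.device_count()
    mp.spawn(run, args=(n,), nprocs=n)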
@@ -26,16 +26,19 @@ _log = logging.getLogger(__name__)
 class GPUAccelerator(Accelerator):
     """ Accelerator for GPU devices. """
 
+    def setup_environment(self) -> None:
+        super().setup_environment()
+        if "cuda" not in str(self.root_device):
+            raise MisconfigurationException(f"Device should be GPU, got {self.root_device} instead")
+        torch.cuda.set_device(self.root_device)
+
     def setup(self, trainer: 'pl.Trainer', model: 'pl.LightningModule') -> None:
         """
         Raises:
             MisconfigurationException:
                 If the selected device is not GPU.
         """
-        if "cuda" not in str(self.root_device):
-            raise MisconfigurationException(f"Device should be GPU, got {self.root_device} instead")
         self.set_nvidia_flags(trainer.local_rank)
-        torch.cuda.set_device(self.root_device)
         return super().setup(trainer, model)
 
     def on_train_start(self) -> None:
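The hunk above is an ordering change: the device check and `torch.cuda.set_device()` now run in `setup_environment()`, which the trainer invokes before `setup()`. A stripped-down sketch of that ordering (a hypothetical class, not Lightning's `GPUAccelerator`):

import torch


class TinyGPUAccelerator:
    """Sketch of the new hook order: bind the device first, set up later."""

    def __init__(self, root_device: torch.device) -> None:
        self.root_device = root_device

    def setup_environment(self) -> None:
        # Runs first: bind this process to its GPU so any collective call
        # issued between setup_environment() and setup() hits the right device.
        if self.root_device.type != "cuda":
            raise ValueError(f"Device should be GPU, got {self.root_device} instead")
        torch.cuda.set_device(self.root_device)

    def setup(self, model: torch.nn.Module) -> torch.nn.Module:
        # Runs later: the device is already current, so only the move remains.
        return model.to(self.root_device)
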
@@ -367,8 +367,6 @@ class DDPPlugin(ParallelPlugin):
         prepare_for_backward(self.model, closure_loss)
 
     def model_to_device(self):
-        if self.root_device.type == "cuda":
-            torch.cuda.set_device(self.root_device)
         self.model.to(self.root_device)
 
     def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Union[ReduceOp, str] = "mean") -> torch.Tensor:
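With the device bound once during environment setup, each plugin's `model_to_device()` reduces to the move itself, which is what the removals in this and the following hunks accomplish. A small illustration under that assumption (device index 0 is hypothetical):

import torch

device = torch.device("cuda", 0)   # hypothetical root_device
torch.cuda.set_device(device)      # done once, in setup_environment()

# Any later model_to_device() only needs the move; no per-plugin set_device():
model = torch.nn.Linear(4, 4).to(device)
assert torch.cuda.current_device() == device.index
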
@@ -339,8 +339,6 @@ class DeepSpeedPlugin(DDPPlugin):
         if not self._config_initialized:
             self._format_config()
             self._config_initialized = True
-        if self.on_gpu:
-            torch.cuda.set_device(self.root_device)
 
     def pre_dispatch(self):
         self.init_deepspeed()
@@ -118,7 +118,6 @@ class DDPFullyShardedPlugin(DDPPlugin):
                 "You selected accelerator to be `ddp_fully_sharded`, but GPU is not available."
             )
         super().setup_distributed()
-        torch.cuda.set_device(self.root_device)
 
     @contextlib.contextmanager
     def model_sharded_context(self) -> Generator:
@@ -61,9 +61,6 @@ class SingleDevicePlugin(TrainingTypePlugin):
         return self.device
 
     def model_to_device(self) -> None:
-        if self.on_gpu:
-            torch.cuda.set_device(self.root_device)
-
         self._model.to(self.root_device)
 
     def setup(self, model: torch.nn.Module) -> torch.nn.Module: