move `torch.cuda.set_device()` to enable collective calls earlier in setup (#8312)

This commit is contained in:
Adrian Wälchli 2021-07-07 13:15:41 +02:00 committed by GitHub
parent 20df24d2a2
commit d73c32ab51
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 6 additions and 11 deletions

View File

@@ -26,16 +26,19 @@ _log = logging.getLogger(__name__)
class GPUAccelerator(Accelerator):
""" Accelerator for GPU devices. """
def setup_environment(self) -> None:
super().setup_environment()
if "cuda" not in str(self.root_device):
raise MisconfigurationException(f"Device should be GPU, got {self.root_device} instead")
torch.cuda.set_device(self.root_device)
def setup(self, trainer: 'pl.Trainer', model: 'pl.LightningModule') -> None:
"""
Raises:
MisconfigurationException:
If the selected device is not GPU.
"""
if "cuda" not in str(self.root_device):
raise MisconfigurationException(f"Device should be GPU, got {self.root_device} instead")
self.set_nvidia_flags(trainer.local_rank)
torch.cuda.set_device(self.root_device)
return super().setup(trainer, model)
def on_train_start(self) -> None:

View File

@@ -367,8 +367,6 @@ class DDPPlugin(ParallelPlugin):
prepare_for_backward(self.model, closure_loss)
def model_to_device(self):
if self.root_device.type == "cuda":
torch.cuda.set_device(self.root_device)
self.model.to(self.root_device)
def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Union[ReduceOp, str] = "mean") -> torch.Tensor:

View File

@@ -339,8 +339,6 @@ class DeepSpeedPlugin(DDPPlugin):
if not self._config_initialized:
self._format_config()
self._config_initialized = True
if self.on_gpu:
torch.cuda.set_device(self.root_device)
def pre_dispatch(self):
self.init_deepspeed()

View File

@@ -118,7 +118,6 @@ class DDPFullyShardedPlugin(DDPPlugin):
"You selected accelerator to be `ddp_fully_sharded`, but GPU is not available."
)
super().setup_distributed()
torch.cuda.set_device(self.root_device)
@contextlib.contextmanager
def model_sharded_context(self) -> Generator:

View File

@@ -61,9 +61,6 @@ class SingleDevicePlugin(TrainingTypePlugin):
return self.device
def model_to_device(self) -> None:
if self.on_gpu:
torch.cuda.set_device(self.root_device)
self._model.to(self.root_device)
def setup(self, model: torch.nn.Module) -> torch.nn.Module: