diff --git a/docs/source/trainer.rst b/docs/source/trainer.rst
index 19c394db48..393650cceb 100644
--- a/docs/source/trainer.rst
+++ b/docs/source/trainer.rst
@@ -9,7 +9,7 @@ Trainer
    :exclude-members:
        run_pretrain_routine,
        _abc_impl,
-        _Trainer__set_random_port,
+        set_random_port,
        _Trainer__set_root_gpu,
        _Trainer__init_optimizers,
        _Trainer__parse_gpu_ids,
diff --git a/pytorch_lightning/trainer/distrib_data_parallel.py b/pytorch_lightning/trainer/distrib_data_parallel.py
index 3b94c9ae3b..68629dc8c1 100644
--- a/pytorch_lightning/trainer/distrib_data_parallel.py
+++ b/pytorch_lightning/trainer/distrib_data_parallel.py
@@ -369,7 +369,7 @@ class TrainerDDPMixin(ABC):
         # don't make this debug... this is good UX
         rank_zero_info(f'CUDA_VISIBLE_DEVICES: [{os.environ["CUDA_VISIBLE_DEVICES"]}]')

-    def __set_random_port(self):
+    def set_random_port(self):
         """
         When running DDP NOT managed by SLURM, the ports might collide
         """
@@ -384,7 +384,6 @@ class TrainerDDPMixin(ABC):
             os.environ['MASTER_PORT'] = str(default_port)

     def spawn_ddp_children(self, model):
-        self.__set_random_port()
         port = os.environ['MASTER_PORT']

         master_address = '127.0.0.1' if 'MASTER_ADDR' not in os.environ else os.environ['MASTER_ADDR']
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 35027cf3c1..cd77f1caa8 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -897,18 +897,19 @@ class Trainer(
             self.ddp_train(task, model)

         elif self.distributed_backend == 'cpu_ddp':
-            self.__set_random_port()
+            self.set_random_port()
             self.model = model
             mp.spawn(self.ddp_train, nprocs=self.num_processes, args=(model,))

         elif self.distributed_backend == 'ddp_spawn':
-            self.__set_random_port()
+            self.set_random_port()
             model.share_memory()

             # spin up peers
             mp.spawn(self.ddp_train, nprocs=self.num_processes, args=(model, ))

         elif self.distributed_backend == 'ddp':
+            self.set_random_port()
             self.spawn_ddp_children(model)

         # 1 gpu or dp option triggers training using DP module
@@ -1273,7 +1274,6 @@ class _PatchDataLoader(object):
         dataloader: Dataloader object to return when called.

     """
-
     def __init__(self, dataloader: Union[List[DataLoader], DataLoader]):
         self.dataloader = dataloader

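
A brief note on the rename, with an illustrative sketch (the stripped-down Trainer class below is a hypothetical stand-in, not the Lightning class): Python mangles double-underscore method names, so __set_random_port is only reachable from outside the class as _Trainer__set_random_port, which is why the Sphinx exclude-members list used that spelling. A public set_random_port is not mangled, so the docs entry and the call sites in trainer.py refer to it by its plain name, and the call needs parentheses; a bare self.set_random_port only looks up the bound method without running it.

class Trainer:
    def __set_random_port(self):
        # Double leading underscores trigger name mangling: from outside the
        # class this method is only visible as Trainer._Trainer__set_random_port.
        return "mangled"

    def set_random_port(self):
        # A public name is left untouched, so docs and callers use it directly.
        return "plain"


t = Trainer()
assert t._Trainer__set_random_port() == "mangled"  # the old, mangled spelling
assert t.set_random_port() == "plain"               # the renamed, public spelling

t.set_random_port    # without parentheses this is only a bound-method lookup; nothing runs
t.set_random_port()  # the parentheses are what actually invoke the port setup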