made fx public (#2247)
* made fx public * made fx public * made fx public
This commit is contained in:
parent
68a1e52292
commit
b7fc092bf4
|
@ -9,7 +9,7 @@ Trainer
|
|||
:exclude-members:
|
||||
run_pretrain_routine,
|
||||
_abc_impl,
|
||||
_Trainer__set_random_port,
|
||||
_Trainer_set_random_port,
|
||||
_Trainer__set_root_gpu,
|
||||
_Trainer__init_optimizers,
|
||||
_Trainer__parse_gpu_ids,
|
||||
|
|
|
@ -369,7 +369,7 @@ class TrainerDDPMixin(ABC):
|
|||
# don't make this debug... this is good UX
|
||||
rank_zero_info(f'CUDA_VISIBLE_DEVICES: [{os.environ["CUDA_VISIBLE_DEVICES"]}]')
|
||||
|
||||
def __set_random_port(self):
|
||||
def set_random_port(self):
|
||||
"""
|
||||
When running DDP NOT managed by SLURM, the ports might collide
|
||||
"""
|
||||
|
@ -384,7 +384,6 @@ class TrainerDDPMixin(ABC):
|
|||
os.environ['MASTER_PORT'] = str(default_port)
|
||||
|
||||
def spawn_ddp_children(self, model):
|
||||
self.__set_random_port()
|
||||
port = os.environ['MASTER_PORT']
|
||||
|
||||
master_address = '127.0.0.1' if 'MASTER_ADDR' not in os.environ else os.environ['MASTER_ADDR']
|
||||
|
|
|
@ -897,18 +897,19 @@ class Trainer(
|
|||
self.ddp_train(task, model)
|
||||
|
||||
elif self.distributed_backend == 'cpu_ddp':
|
||||
self.__set_random_port()
|
||||
self._set_random_port
|
||||
self.model = model
|
||||
mp.spawn(self.ddp_train, nprocs=self.num_processes, args=(model,))
|
||||
|
||||
elif self.distributed_backend == 'ddp_spawn':
|
||||
self.__set_random_port()
|
||||
self._set_random_port
|
||||
model.share_memory()
|
||||
|
||||
# spin up peers
|
||||
mp.spawn(self.ddp_train, nprocs=self.num_processes, args=(model, ))
|
||||
|
||||
elif self.distributed_backend == 'ddp':
|
||||
self._set_random_port
|
||||
self.spawn_ddp_children(model)
|
||||
|
||||
# 1 gpu or dp option triggers training using DP module
|
||||
|
@ -1273,7 +1274,6 @@ class _PatchDataLoader(object):
|
|||
dataloader: Dataloader object to return when called.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, dataloader: Union[List[DataLoader], DataLoader]):
|
||||
self.dataloader = dataloader
|
||||
|
||||
|
|
Loading…
Reference in New Issue