reduced accelerator selection (#3211)

commit eb12f58edf (parent 4272360076)
Author: William Falcon
Date: 2020-08-26 21:29:10 -04:00
Committed by: GitHub
1 changed file with 36 additions and 53 deletions


@@ -1021,9 +1021,31 @@ class Trainer(
         # set testing if set in environ
         self.testing = os.environ.get('PL_TESTING_MODE', self.testing)
 
-        # -------------------
-        # determine ddp mode
-        # -------------------
+        # choose accelerator
+        self.accelerator_backend = self.select_accelerator()
+
+        # setup accelerator
+        self.accelerator_backend.setup(model)
+
+        # train!
+        results = self.accelerator_backend.train()
+
+        # teardown accelerator
+        self.accelerator_backend.teardown()
+
+        # hook
+        self.call_hook('on_fit_end')
+
+        # hook
+        self.teardown('fit')
+        if self.is_function_implemented('teardown'):
+            model.teardown('fit')
+
+        # return 1 when finished
+        # used for testing or when we need to know that training succeeded
+        return results or 1
+
+    def select_accelerator(self):
         # SLURM ddp
         use_slurm_ddp = self.use_ddp and self.is_slurm_managing_tasks
 
@@ -1038,79 +1060,40 @@ class Trainer(
         # -------------------
         # DDP2 (cluster only)
         if self.use_ddp2:
-            self.accelerator_backend = DDP2Backend(self)
-            self.accelerator_backend.setup(model)
-            results = self.accelerator_backend.train()
-            self.accelerator_backend.teardown()
+            accelerator_backend = DDP2Backend(self)
 
         elif use_slurm_ddp:
-            self.accelerator_backend = DDPBackend(self, mode='slurm_ddp')
-            self.accelerator_backend.setup(model)
-            results = self.accelerator_backend.train()
-            self.accelerator_backend.teardown()
+            accelerator_backend = DDPBackend(self, mode='slurm_ddp')
 
         elif use_torchelastic_ddp:
-            self.accelerator_backend = DDPBackend(self, mode='torchelastic_ddp')
-            self.accelerator_backend.setup(model)
-            results = self.accelerator_backend.train()
-            self.accelerator_backend.teardown()
+            accelerator_backend = DDPBackend(self, mode='torchelastic_ddp')
 
         # regular ddp using .spawn
         elif use_ddp_spawn:
-            self.accelerator_backend = DDPSpawnBackend(self, nprocs=self.num_processes)
-            self.accelerator_backend.setup(model)
-            results = self.accelerator_backend.train()
-            self.accelerator_backend.teardown()
+            accelerator_backend = DDPSpawnBackend(self, nprocs=self.num_processes)
 
         # ddp
         elif self.distributed_backend == 'ddp':
-            self.accelerator_backend = DDPBackend(self, mode='ddp')
-            self.accelerator_backend.setup(model)
-            results = self.accelerator_backend.train()
-            self.accelerator_backend.teardown()
+            accelerator_backend = DDPBackend(self, mode='ddp')
 
         # dp
         elif self.use_dp:
-            self.accelerator_backend = DataParallelBackend(self)
-            self.accelerator_backend.setup(model)
-            results = self.accelerator_backend.train()
-            self.accelerator_backend.teardown()
+            accelerator_backend = DataParallelBackend(self)
 
         elif self.use_horovod:
-            self.accelerator_backend = HorovodBackend(self)
-            self.accelerator_backend.setup(model)
-            results = self.accelerator_backend.train()
-            self.accelerator_backend.teardown()
+            accelerator_backend = HorovodBackend(self)
 
         elif self.use_single_gpu:
-            self.accelerator_backend = GPUBackend(self)
-            self.accelerator_backend.setup(model)
-            results = self.accelerator_backend.train()
-            self.accelerator_backend.teardown()
+            accelerator_backend = GPUBackend(self)
 
         elif self.use_tpu:
-            self.accelerator_backend = TPUBackend(self)
-            self.accelerator_backend.setup(model)
-            results = self.accelerator_backend.train()
-            self.accelerator_backend.teardown()
+            accelerator_backend = TPUBackend(self)
 
         else:
-            self.accelerator_backend = CPUBackend(self)
-            self.accelerator_backend.setup(model)
-            results = self.accelerator_backend.train()
-            self.accelerator_backend.teardown()
+            accelerator_backend = CPUBackend(self)
 
-        # hook
-        self.call_hook('on_fit_end')
-
-        # hook
-        self.teardown('fit')
-        if self.is_function_implemented('teardown'):
-            model.teardown('fit')
-
-        # return 1 when finished
-        # used for testing or when we need to know that training succeeded
-        return results or 1
+        return accelerator_backend
 
     def can_prepare_data(self):
         should_call_dm_prepare_data = True
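
For readers skimming the change: the refactor replaces ten copies of the setup/train/teardown sequence with a single code path in fit(), and isolates the branching in a new select_accelerator() factory that only constructs a backend. Below is a minimal, self-contained sketch of that pattern; the backend classes and the use_single_gpu flag are simplified stand-ins chosen for illustration, not Lightning's actual implementations.

# Sketch of the pattern this commit introduces. The backends here are
# simplified stand-ins (assumption), not Lightning's real accelerator classes.

class CPUBackend:
    def __init__(self, trainer):
        self.trainer = trainer

    def setup(self, model):
        # attach the model; a real backend would also move it to its device
        self.model = model

    def train(self):
        print('training on CPU')
        return 1

    def teardown(self):
        # release anything acquired in setup()
        pass


class GPUBackend(CPUBackend):
    def train(self):
        print('training on a single GPU')
        return 1


class Trainer:
    def __init__(self, use_single_gpu=False):
        self.use_single_gpu = use_single_gpu

    def select_accelerator(self):
        # every branch only constructs a backend; nothing runs here
        if self.use_single_gpu:
            return GPUBackend(self)
        return CPUBackend(self)

    def fit(self, model):
        # choose accelerator
        self.accelerator_backend = self.select_accelerator()

        # setup / train / teardown now happen exactly once, for every backend
        self.accelerator_backend.setup(model)
        results = self.accelerator_backend.train()
        self.accelerator_backend.teardown()

        # mirror the diff's `return results or 1`
        return results or 1


if __name__ == '__main__':
    Trainer(use_single_gpu=False).fit(model=object())

The design payoff: adding a new backend now touches only select_accelerator(), while the fit() lifecycle stays a single, uniform code path.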