diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index caa7c2ddd8..24236a6eb4 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -1021,9 +1021,31 @@ class Trainer(
         # set testing if set in environ
         self.testing = os.environ.get('PL_TESTING_MODE', self.testing)
 
-        # -------------------
-        # determine ddp mode
-        # -------------------
+        # choose accelerator
+        self.accelerator_backend = self.select_accelerator()
+
+        # setup accelerator
+        self.accelerator_backend.setup(model)
+
+        # train!
+        results = self.accelerator_backend.train()
+
+        # teardown accelerator
+        self.accelerator_backend.teardown()
+
+        # hook
+        self.call_hook('on_fit_end')
+
+        # hook
+        self.teardown('fit')
+        if self.is_function_implemented('teardown'):
+            model.teardown('fit')
+
+        # return 1 when finished
+        # used for testing or when we need to know that training succeeded
+        return results or 1
+
+    def select_accelerator(self):
         # SLURM ddp
         use_slurm_ddp = self.use_ddp and self.is_slurm_managing_tasks
 
@@ -1038,79 +1060,40 @@ class Trainer(
         # -------------------
         # DDP2 (cluster only)
         if self.use_ddp2:
-            self.accelerator_backend = DDP2Backend(self)
-            self.accelerator_backend.setup(model)
-            results = self.accelerator_backend.train()
-            self.accelerator_backend.teardown()
+            accelerator_backend = DDP2Backend(self)
 
         elif use_slurm_ddp:
-            self.accelerator_backend = DDPBackend(self, mode='slurm_ddp')
-            self.accelerator_backend.setup(model)
-            results = self.accelerator_backend.train()
-            self.accelerator_backend.teardown()
+            accelerator_backend = DDPBackend(self, mode='slurm_ddp')
 
         elif use_torchelastic_ddp:
-            self.accelerator_backend = DDPBackend(self, mode='torchelastic_ddp')
-            self.accelerator_backend.setup(model)
-            results = self.accelerator_backend.train()
-            self.accelerator_backend.teardown()
+            accelerator_backend = DDPBackend(self, mode='torchelastic_ddp')
 
         # regular ddp using .spawn
         elif use_ddp_spawn:
-            self.accelerator_backend = DDPSpawnBackend(self, nprocs=self.num_processes)
-            self.accelerator_backend.setup(model)
-            results = self.accelerator_backend.train()
-            self.accelerator_backend.teardown()
+            accelerator_backend = DDPSpawnBackend(self, nprocs=self.num_processes)
 
         # ddp
         elif self.distributed_backend == 'ddp':
-            self.accelerator_backend = DDPBackend(self, mode='ddp')
-            self.accelerator_backend.setup(model)
-            results = self.accelerator_backend.train()
-            self.accelerator_backend.teardown()
+            accelerator_backend = DDPBackend(self, mode='ddp')
 
         # dp
         elif self.use_dp:
-            self.accelerator_backend = DataParallelBackend(self)
-            self.accelerator_backend.setup(model)
-            results = self.accelerator_backend.train()
-            self.accelerator_backend.teardown()
+            accelerator_backend = DataParallelBackend(self)
 
         elif self.use_horovod:
-            self.accelerator_backend = HorovodBackend(self)
-            self.accelerator_backend.setup(model)
-            results = self.accelerator_backend.train()
-            self.accelerator_backend.teardown()
+            accelerator_backend = HorovodBackend(self)
 
         elif self.use_single_gpu:
-            self.accelerator_backend = GPUBackend(self)
-            self.accelerator_backend.setup(model)
-            results = self.accelerator_backend.train()
-            self.accelerator_backend.teardown()
+            accelerator_backend = GPUBackend(self)
 
         elif self.use_tpu:
-            self.accelerator_backend = TPUBackend(self)
-            self.accelerator_backend.setup(model)
-            results = self.accelerator_backend.train()
-            self.accelerator_backend.teardown()
+            accelerator_backend = TPUBackend(self)
 
         else:
-            self.accelerator_backend = CPUBackend(self)
-            self.accelerator_backend.setup(model)
-            results = self.accelerator_backend.train()
-            self.accelerator_backend.teardown()
+            accelerator_backend = CPUBackend(self)
 
-        # hook
-        self.call_hook('on_fit_end')
+        return accelerator_backend
 
-        # hook
-        self.teardown('fit')
-        if self.is_function_implemented('teardown'):
-            model.teardown('fit')
-
-        # return 1 when finished
-        # used for testing or when we need to know that training succeeded
-        return results or 1
 
     def can_prepare_data(self):
         should_call_dm_prepare_data = True
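The net effect of this diff is a classic extract-method refactor: the ten near-identical setup(model) / train() / teardown() sequences collapse into a single lifecycle in fit(), and the if/elif chain becomes a pure factory that only picks a backend. The sketch below shows the same pattern in miniature; MiniTrainer and the _StubBackend hierarchy are hypothetical stand-ins for illustration, not the real Lightning classes.

# Minimal sketch of the select-then-run pattern introduced by the diff.
# All class names here are hypothetical illustrations.

class _StubBackend:
    """Stand-in for the role of DDPBackend, GPUBackend, CPUBackend, etc."""

    def __init__(self, trainer):
        self.trainer = trainer

    def setup(self, model):
        print(f'{type(self).__name__}: setup')

    def train(self):
        print(f'{type(self).__name__}: train')
        return 1

    def teardown(self):
        print(f'{type(self).__name__}: teardown')


class _GPUBackend(_StubBackend):
    pass


class _CPUBackend(_StubBackend):
    pass


class MiniTrainer:
    def __init__(self, use_gpu=False):
        self.use_gpu = use_gpu
        self.accelerator_backend = None

    def select_accelerator(self):
        # Branching decides only *which* backend to construct ...
        if self.use_gpu:
            return _GPUBackend(self)
        return _CPUBackend(self)

    def fit(self, model):
        # ... while the shared setup/train/teardown lifecycle is written once.
        self.accelerator_backend = self.select_accelerator()
        self.accelerator_backend.setup(model)
        results = self.accelerator_backend.train()
        self.accelerator_backend.teardown()
        return results or 1


if __name__ == '__main__':
    MiniTrainer(use_gpu=False).fit(model=object())  # prints the CPU lifecycle

Returning the backend instead of mutating state inside select_accelerator keeps the selection side-effect-free and easy to unit-test in isolation, mirroring how the diff builds a local accelerator_backend and leaves the single self.accelerator_backend assignment to fit().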