DP device fix (#3196)

Philipp Singer 2020-08-27 15:01:29 +02:00 committed by GitHub
parent 4d98419bb8
commit 0aee137ba7
1 changed file with 3 additions and 3 deletions


@@ -47,6 +47,9 @@ class DataParallelBackend(Accelerator):
         self.trainer.lr_schedulers = lr_schedulers
         self.trainer.optimizer_frequencies = optimizer_frequencies

+        # init torch data parallel
+        model = self.__init_torch_data_parallel(model)
+
         # hack forward to do autocast for the user
         self.model_autocast_original_forward = model.forward

@@ -54,9 +57,6 @@ class DataParallelBackend(Accelerator):
         if self.trainer.amp_backend:
             model = self.__init_half_precision(model)

-        # init torch data parallel
-        model = self.__init_torch_data_parallel(model)
-
         self.trainer.model = model

     def __init_torch_data_parallel(self, model):
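
For context on what the reordering achieves: after this change the model is wrapped in torch.nn.DataParallel before the autocast forward hack and the half-precision setup run, so the forward that gets captured and patched belongs to the DataParallel wrapper rather than the bare module. The following is a minimal, self-contained sketch of that ordering in plain PyTorch; the function name setup_dp_model, its arguments, and the autocast wrapper are illustrative assumptions, not the actual Lightning implementation.

import torch
from torch.nn.parallel import DataParallel


def setup_dp_model(model, device_ids, use_amp):
    # Wrap the model in torch.nn.DataParallel first (the step this commit moves up).
    model = DataParallel(model, device_ids=device_ids)

    # Capture the forward of the *wrapped* model, so the patching below
    # targets DataParallel.forward rather than the bare module's forward.
    original_forward = model.forward

    if use_amp:
        # Illustrative autocast hack (an assumption, not Lightning's code):
        # run the captured forward under torch.cuda.amp.autocast for the user.
        def autocast_forward(*args, **kwargs):
            with torch.cuda.amp.autocast():
                return original_forward(*args, **kwargs)

        model.forward = autocast_forward

    return model

A caller would then do something like model = setup_dp_model(model, device_ids=[0, 1], use_amp=True). Because DataParallel scatters inputs across device_ids inside its own forward, any forward replacement applied afterwards wraps that scattering logic as well.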