From 0aee137ba7a2d0c1b7eb01168e9d1611a1546f1f Mon Sep 17 00:00:00 2001
From: Philipp Singer
Date: Thu, 27 Aug 2020 15:01:29 +0200
Subject: [PATCH] DP device fix (#3196)

---
 pytorch_lightning/accelerators/dp_backend.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/pytorch_lightning/accelerators/dp_backend.py b/pytorch_lightning/accelerators/dp_backend.py
index 1f8a1ef181..1d1f341586 100644
--- a/pytorch_lightning/accelerators/dp_backend.py
+++ b/pytorch_lightning/accelerators/dp_backend.py
@@ -47,6 +47,9 @@ class DataParallelBackend(Accelerator):
         self.trainer.lr_schedulers = lr_schedulers
         self.trainer.optimizer_frequencies = optimizer_frequencies
 
+        # init torch data parallel
+        model = self.__init_torch_data_parallel(model)
+
         # hack forward to do autocast for the user
         self.model_autocast_original_forward = model.forward
 
@@ -54,9 +57,6 @@ class DataParallelBackend(Accelerator):
         if self.trainer.amp_backend:
             model = self.__init_half_precision(model)
 
-        # init torch data parallel
-        model = self.__init_torch_data_parallel(model)
-
         self.trainer.model = model
 
     def __init_torch_data_parallel(self, model):
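
Note (commentary, not part of the patch): the change moves the torch.nn.DataParallel wrapping ahead of the forward capture and AMP steps, so the forward reference saved into model_autocast_original_forward (and later patched by __init_half_precision) belongs to the DP wrapper rather than the bare module. Below is a minimal sketch of that ordering in plain PyTorch, not Lightning's actual code; the nn.Linear toy model and variable names are hypothetical, and it assumes a machine with at least one CUDA device:

    import torch
    import torch.nn as nn

    model = nn.Linear(8, 1).cuda(0)

    # Fixed ordering: wrap in DataParallel *before* capturing forward, so
    # the saved reference points at the DP wrapper, whose forward scatters
    # inputs onto the replicas' devices.
    model = nn.DataParallel(model, device_ids=[0])

    # Analogue of `self.model_autocast_original_forward = model.forward`:
    model_autocast_original_forward = model.forward
    print(model_autocast_original_forward.__self__.__class__.__name__)
    # -> DataParallel

    # With the pre-patch ordering, the capture happened on the bare module,
    # so a later AMP forward patch would wrap Linear.forward and bypass
    # DataParallel's device handling.
    out = model(torch.randn(4, 8, device="cuda:0"))
    print(out.shape)  # torch.Size([4, 1])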