diff --git a/pytorch_lightning/accelerators/dp_backend.py b/pytorch_lightning/accelerators/dp_backend.py index 1f8a1ef181..1d1f341586 100644 --- a/pytorch_lightning/accelerators/dp_backend.py +++ b/pytorch_lightning/accelerators/dp_backend.py @@ -47,6 +47,9 @@ class DataParallelBackend(Accelerator): self.trainer.lr_schedulers = lr_schedulers self.trainer.optimizer_frequencies = optimizer_frequencies + # init torch data parallel + model = self.__init_torch_data_parallel(model) + # hack forward to do autocast for the user self.model_autocast_original_forward = model.forward @@ -54,9 +57,6 @@ class DataParallelBackend(Accelerator): if self.trainer.amp_backend: model = self.__init_half_precision(model) - # init torch data parallel - model = self.__init_torch_data_parallel(model) - self.trainer.model = model def __init_torch_data_parallel(self, model):