diff --git a/pytorch_lightning/trainer/connectors/model_connector.py b/pytorch_lightning/trainer/connectors/model_connector.py index eac8b2785d..efd22fee1d 100644 --- a/pytorch_lightning/trainer/connectors/model_connector.py +++ b/pytorch_lightning/trainer/connectors/model_connector.py @@ -24,7 +24,7 @@ class ModelConnector: self.trainer = trainer def copy_trainer_model_properties(self, model): - ref_model = self._reference_model + ref_model = self._reference_model(model) automatic_optimization = ref_model.automatic_optimization and self.trainer.train_loop.automatic_optimization self.trainer.train_loop.automatic_optimization = automatic_optimization @@ -46,12 +46,11 @@ class ModelConnector: m.local_rank = self.trainer.local_rank def get_model(self): - return self._reference_model + return self._reference_model(self.trainer.model) - @property - def _reference_model(self): + def _reference_model(self, model): if self.trainer.accelerator_backend: - ref_model = self.trainer.accelerator_backend.model + ref_model = self.trainer.accelerator_backend.reference_model(model) else: - ref_model = self.trainer.model + ref_model = model return ref_model