Revert property as this is incorrect.=
This commit is contained in:
parent
4b16b47843
commit
977625c289
|
@ -24,7 +24,7 @@ class ModelConnector:
|
||||||
self.trainer = trainer
|
self.trainer = trainer
|
||||||
|
|
||||||
def copy_trainer_model_properties(self, model):
|
def copy_trainer_model_properties(self, model):
|
||||||
ref_model = self._reference_model
|
ref_model = self._reference_model(model)
|
||||||
|
|
||||||
automatic_optimization = ref_model.automatic_optimization and self.trainer.train_loop.automatic_optimization
|
automatic_optimization = ref_model.automatic_optimization and self.trainer.train_loop.automatic_optimization
|
||||||
self.trainer.train_loop.automatic_optimization = automatic_optimization
|
self.trainer.train_loop.automatic_optimization = automatic_optimization
|
||||||
|
@ -46,12 +46,11 @@ class ModelConnector:
|
||||||
m.local_rank = self.trainer.local_rank
|
m.local_rank = self.trainer.local_rank
|
||||||
|
|
||||||
def get_model(self):
|
def get_model(self):
|
||||||
return self._reference_model
|
return self._reference_model(self.trainer.model)
|
||||||
|
|
||||||
@property
|
def _reference_model(self, model):
|
||||||
def _reference_model(self):
|
|
||||||
if self.trainer.accelerator_backend:
|
if self.trainer.accelerator_backend:
|
||||||
ref_model = self.trainer.accelerator_backend.model
|
ref_model = self.trainer.accelerator_backend.reference_model(model)
|
||||||
else:
|
else:
|
||||||
ref_model = self.trainer.model
|
ref_model = model
|
||||||
return ref_model
|
return ref_model
|
||||||
|
|
Loading…
Reference in New Issue