diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 52406313f8..712a11fd37 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1193,7 +1193,8 @@ class Trainer( self.teardown('test') if self.is_function_implemented('teardown'): - self.model.teardown('test') + model_ref = self.get_model() + model_ref.teardown('test') def check_model_configuration(self, model: LightningModule): r"""