diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py index 3f95d04569..caf0f2a198 100644 --- a/pytorch_lightning/models/trainer.py +++ b/pytorch_lightning/models/trainer.py @@ -643,7 +643,7 @@ class Trainer(TrainerIO): self.enable_auto_hpc_walltime_manager() # run tiny validation to make sure program won't crash during val - model.on_sanity_check_start() + ref_model.on_sanity_check_start() _ = self.validate(model, self.val_dataloader, max_batches=self.nb_sanity_val_steps) # ---------------------------