diff --git a/examples/fabric/build_your_own_trainer/trainer.py b/examples/fabric/build_your_own_trainer/trainer.py
index a225bf5556..7af01ede05 100644
--- a/examples/fabric/build_your_own_trainer/trainer.py
+++ b/examples/fabric/build_your_own_trainer/trainer.py
@@ -264,7 +264,7 @@ class MyCustomTrainer:
         val_loader: Optional[torch.utils.data.DataLoader],
         limit_batches: Union[int, float] = float("inf"),
     ):
-        """The validation loop ruunning a single validation epoch.
+        """The validation loop running a single validation epoch.
 
         Args:
             model: the LightningModule to evaluate
@@ -285,7 +285,10 @@ class MyCustomTrainer:
             )
             return
 
-        self.fabric.call("on_validation_model_eval")  # calls `model.eval()`
+        if not is_overridden("on_validation_model_eval", _unwrap_objects(model)):
+            model.eval()
+        else:
+            self.fabric.call("on_validation_model_eval")  # calls `model.eval()`
 
         torch.set_grad_enabled(False)
 
@@ -311,7 +314,10 @@ class MyCustomTrainer:
 
         self.fabric.call("on_validation_epoch_end")
 
-        self.fabric.call("on_validation_model_train")
+        if not is_overridden("on_validation_model_train", _unwrap_objects(model)):
+            model.train()
+        else:
+            self.fabric.call("on_validation_model_train")
         torch.set_grad_enabled(True)
 
     def training_step(self, model: L.LightningModule, batch: Any, batch_idx: int) -> torch.Tensor:
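
For context, both hunks apply the same guard: only dispatch the hook through the callback machinery when the user actually overrode it on their LightningModule; otherwise fall back to the default behavior (model.eval() / model.train()) directly. Below is a minimal, self-contained sketch of that guard pattern, assuming nothing beyond the standard library; BaseModule, MyModule, enter_validation, and the toy is_overridden are illustrative stand-ins, not Lightning's real LightningModule, fabric.call, or is_overridden utility.

    # Sketch of the "dispatch the hook only if overridden" guard used in the diff.
    # All names here are hypothetical stand-ins for the Lightning equivalents.


    class BaseModule:
        def on_validation_model_eval(self) -> None:
            """Default hook: no custom behavior."""


    class MyModule(BaseModule):
        def on_validation_model_eval(self) -> None:
            print("custom eval-mode setup")


    def is_overridden(method_name: str, instance: object, parent: type = BaseModule) -> bool:
        # "Overridden" means the instance's class provides its own implementation
        # rather than inheriting the parent's.
        return getattr(type(instance), method_name) is not getattr(parent, method_name)


    def enter_validation(model: BaseModule) -> None:
        if not is_overridden("on_validation_model_eval", model):
            # No user hook: run the default behavior directly
            # (the diff calls model.eval() here).
            print("model.eval()  # default path")
        else:
            # User hook present: dispatch it
            # (the diff routes this through self.fabric.call(...)).
            model.on_validation_model_eval()


    enter_validation(BaseModule())  # prints: model.eval()  # default path
    enter_validation(MyModule())    # prints: custom eval-mode setup

The design point of the change itself: unconditionally routing through the callback dispatcher made the default model.eval()/model.train() depend on the hook being registered, whereas this guard keeps the defaults working even when no hook is overridden.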