Updated Fabric trainer example to not call `self.trainer.model` during validation (#19993)

This commit is contained in:
liambsmith 2024-06-21 10:43:30 -04:00 committed by GitHub
parent 5981aebfcc
commit 709a2a9d3b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed file with 9 additions and 3 deletions

View File

@@ -264,7 +264,7 @@ class MyCustomTrainer:
val_loader: Optional[torch.utils.data.DataLoader],
limit_batches: Union[int, float] = float("inf"),
):
"""The validation loop ruunning a single validation epoch.
"""The validation loop running a single validation epoch.
Args:
model: the LightningModule to evaluate
@@ -285,7 +285,10 @@ class MyCustomTrainer:
)
return
self.fabric.call("on_validation_model_eval") # calls `model.eval()`
if not is_overridden("on_validation_model_eval", _unwrap_objects(model)):
model.eval()
else:
self.fabric.call("on_validation_model_eval") # calls `model.eval()`
torch.set_grad_enabled(False)
@@ -311,7 +314,10 @@ class MyCustomTrainer:
self.fabric.call("on_validation_epoch_end")
self.fabric.call("on_validation_model_train")
if not is_overridden("on_validation_model_train", _unwrap_objects(model)):
model.train()
else:
self.fabric.call("on_validation_model_train")
torch.set_grad_enabled(True)
def training_step(self, model: L.LightningModule, batch: Any, batch_idx: int) -> torch.Tensor: