Updated Fabric trainer example to not call `self.trainer.model` during validation (#19993)
This commit is contained in:
parent
5981aebfcc
commit
709a2a9d3b
|
@ -264,7 +264,7 @@ class MyCustomTrainer:
|
|||
val_loader: Optional[torch.utils.data.DataLoader],
|
||||
limit_batches: Union[int, float] = float("inf"),
|
||||
):
|
||||
"""The validation loop ruunning a single validation epoch.
|
||||
"""The validation loop running a single validation epoch.
|
||||
|
||||
Args:
|
||||
model: the LightningModule to evaluate
|
||||
|
@ -285,7 +285,10 @@ class MyCustomTrainer:
|
|||
)
|
||||
return
|
||||
|
||||
self.fabric.call("on_validation_model_eval") # calls `model.eval()`
|
||||
if not is_overridden("on_validation_model_eval", _unwrap_objects(model)):
|
||||
model.eval()
|
||||
else:
|
||||
self.fabric.call("on_validation_model_eval") # calls `model.eval()`
|
||||
|
||||
torch.set_grad_enabled(False)
|
||||
|
||||
|
@ -311,7 +314,10 @@ class MyCustomTrainer:
|
|||
|
||||
self.fabric.call("on_validation_epoch_end")
|
||||
|
||||
self.fabric.call("on_validation_model_train")
|
||||
if not is_overridden("on_validation_model_train", _unwrap_objects(model)):
|
||||
model.train()
|
||||
else:
|
||||
self.fabric.call("on_validation_model_train")
|
||||
torch.set_grad_enabled(True)
|
||||
|
||||
def training_step(self, model: L.LightningModule, batch: Any, batch_idx: int) -> torch.Tensor:
|
||||
|
|
Loading…
Reference in New Issue