diff --git a/examples/fabric/build_your_own_trainer/trainer.py b/examples/fabric/build_your_own_trainer/trainer.py index 7ee993039f..7558cfcb05 100644 --- a/examples/fabric/build_your_own_trainer/trainer.py +++ b/examples/fabric/build_your_own_trainer/trainer.py @@ -4,7 +4,7 @@ from functools import partial from typing import Any, cast, Iterable, List, Literal, Optional, Tuple, Union import torch -from lightning_utilities import apply_to_collection, is_overridden +from lightning_utilities import apply_to_collection from tqdm import tqdm import lightning as L @@ -12,6 +12,7 @@ from lightning.fabric.accelerators import Accelerator from lightning.fabric.loggers import Logger from lightning.fabric.strategies import Strategy from lightning.fabric.wrappers import _unwrap_objects +from lightning.pytorch.utilities.model_helpers import is_overridden class MyCustomTrainer: @@ -274,7 +275,7 @@ class MyCustomTrainer: return # no validation but warning if val_loader was passed, but validation_step not implemented - if val_loader is not None and not is_overridden("validation_step", _unwrap_objects(model), L.LightningModule): + if val_loader is not None and not is_overridden("validation_step", _unwrap_objects(model)): L.fabric.utilities.rank_zero_warn( "Your LightningModule does not have a validation_step implemented, " "but you passed a validation dataloder. Skipping Validation."