Use PL `is_overridden` in BYOT example (#17070)
This commit is contained in:
parent
5d102bfa4a
commit
ea44aac10f
|
@ -4,7 +4,7 @@ from functools import partial
|
|||
from typing import Any, cast, Iterable, List, Literal, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from lightning_utilities import apply_to_collection, is_overridden
|
||||
from lightning_utilities import apply_to_collection
|
||||
from tqdm import tqdm
|
||||
|
||||
import lightning as L
|
||||
|
@ -12,6 +12,7 @@ from lightning.fabric.accelerators import Accelerator
|
|||
from lightning.fabric.loggers import Logger
|
||||
from lightning.fabric.strategies import Strategy
|
||||
from lightning.fabric.wrappers import _unwrap_objects
|
||||
from lightning.pytorch.utilities.model_helpers import is_overridden
|
||||
|
||||
|
||||
class MyCustomTrainer:
|
||||
|
@ -274,7 +275,7 @@ class MyCustomTrainer:
|
|||
return
|
||||
|
||||
# no validation but warning if val_loader was passed, but validation_step not implemented
|
||||
if val_loader is not None and not is_overridden("validation_step", _unwrap_objects(model), L.LightningModule):
|
||||
if val_loader is not None and not is_overridden("validation_step", _unwrap_objects(model)):
|
||||
L.fabric.utilities.rank_zero_warn(
|
||||
"Your LightningModule does not have a validation_step implemented, "
|
||||
"but you passed a validation dataloder. Skipping Validation."
|
||||
|
|
Loading…
Reference in New Issue