Use PL `is_overridden` in BYOT example (#17070)

This commit is contained in:
Carlos Mocholí 2023-05-05 14:08:18 +02:00 committed by GitHub
parent 5d102bfa4a
commit ea44aac10f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 3 additions and 2 deletions

View File

@ -4,7 +4,7 @@ from functools import partial
from typing import Any, cast, Iterable, List, Literal, Optional, Tuple, Union
import torch
from lightning_utilities import apply_to_collection, is_overridden
from lightning_utilities import apply_to_collection
from tqdm import tqdm
import lightning as L
@ -12,6 +12,7 @@ from lightning.fabric.accelerators import Accelerator
from lightning.fabric.loggers import Logger
from lightning.fabric.strategies import Strategy
from lightning.fabric.wrappers import _unwrap_objects
from lightning.pytorch.utilities.model_helpers import is_overridden
class MyCustomTrainer:
@ -274,7 +275,7 @@ class MyCustomTrainer:
return
# no validation but warning if val_loader was passed, but validation_step not implemented
if val_loader is not None and not is_overridden("validation_step", _unwrap_objects(model), L.LightningModule):
if val_loader is not None and not is_overridden("validation_step", _unwrap_objects(model)):
L.fabric.utilities.rank_zero_warn(
"Your LightningModule does not have a validation_step implemented, "
"but you passed a validation dataloder. Skipping Validation."