Use PL `is_overridden` in BYOT example (#17070)

This commit is contained in:
Carlos Mocholí 2023-05-05 14:08:18 +02:00 committed by GitHub
parent 5d102bfa4a
commit ea44aac10f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 3 additions and 2 deletions

View File

@ -4,7 +4,7 @@ from functools import partial
from typing import Any, cast, Iterable, List, Literal, Optional, Tuple, Union from typing import Any, cast, Iterable, List, Literal, Optional, Tuple, Union
import torch import torch
from lightning_utilities import apply_to_collection, is_overridden from lightning_utilities import apply_to_collection
from tqdm import tqdm from tqdm import tqdm
import lightning as L import lightning as L
@ -12,6 +12,7 @@ from lightning.fabric.accelerators import Accelerator
from lightning.fabric.loggers import Logger from lightning.fabric.loggers import Logger
from lightning.fabric.strategies import Strategy from lightning.fabric.strategies import Strategy
from lightning.fabric.wrappers import _unwrap_objects from lightning.fabric.wrappers import _unwrap_objects
from lightning.pytorch.utilities.model_helpers import is_overridden
class MyCustomTrainer: class MyCustomTrainer:
@ -274,7 +275,7 @@ class MyCustomTrainer:
return return
# no validation but warning if val_loader was passed, but validation_step not implemented # no validation but warning if val_loader was passed, but validation_step not implemented
if val_loader is not None and not is_overridden("validation_step", _unwrap_objects(model), L.LightningModule): if val_loader is not None and not is_overridden("validation_step", _unwrap_objects(model)):
L.fabric.utilities.rank_zero_warn( L.fabric.utilities.rank_zero_warn(
"Your LightningModule does not have a validation_step implemented, " "Your LightningModule does not have a validation_step implemented, "
"but you passed a validation dataloder. Skipping Validation." "but you passed a validation dataloder. Skipping Validation."