From ea44aac10f970317a8eebc368a686c20a2412a97 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Fri, 5 May 2023 14:08:18 +0200 Subject: [PATCH] Use PL `is_overridden` in BYOT example (#17070) --- examples/fabric/build_your_own_trainer/trainer.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/fabric/build_your_own_trainer/trainer.py b/examples/fabric/build_your_own_trainer/trainer.py index 7ee993039f..7558cfcb05 100644 --- a/examples/fabric/build_your_own_trainer/trainer.py +++ b/examples/fabric/build_your_own_trainer/trainer.py @@ -4,7 +4,7 @@ from functools import partial from typing import Any, cast, Iterable, List, Literal, Optional, Tuple, Union import torch -from lightning_utilities import apply_to_collection, is_overridden +from lightning_utilities import apply_to_collection from tqdm import tqdm import lightning as L @@ -12,6 +12,7 @@ from lightning.fabric.accelerators import Accelerator from lightning.fabric.loggers import Logger from lightning.fabric.strategies import Strategy from lightning.fabric.wrappers import _unwrap_objects +from lightning.pytorch.utilities.model_helpers import is_overridden class MyCustomTrainer: @@ -274,7 +275,7 @@ class MyCustomTrainer: return # no validation but warning if val_loader was passed, but validation_step not implemented - if val_loader is not None and not is_overridden("validation_step", _unwrap_objects(model), L.LightningModule): + if val_loader is not None and not is_overridden("validation_step", _unwrap_objects(model)): L.fabric.utilities.rank_zero_warn( "Your LightningModule does not have a validation_step implemented, " "but you passed a validation dataloder. Skipping Validation."