2023-04-28 09:42:45 +00:00
|
|
|
import torch
|
|
|
|
from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature
|
|
|
|
|
|
|
|
|
|
|
|
def test_param_in_hook_signature():
    """Check ``is_param_in_hook_signature`` across the supported hook signature shapes.

    Covers:
    * a parameter declared explicitly by name,
    * the same, behind a decorator (``torch.no_grad``),
    * a catch-all ``*args`` signature with ``explicit`` on and off,
    * a fixed positional signature, both by name lookup and by ``min_args``.
    """

    # The parameter is declared by name, so even an explicit-only lookup must find it.
    class ExplicitParamModule:
        def validation_step(self, dataloader_iter):
            ...

    model = ExplicitParamModule()
    assert is_param_in_hook_signature(model.validation_step, "dataloader_iter", explicit=True)

    # The hook is wrapped by a decorator; the lookup must still see the parameter
    # declared on the underlying function.
    class DecoratedParamModule:
        @torch.no_grad()
        def validation_step(self, dataloader_iter):
            ...

    model = DecoratedParamModule()
    assert is_param_in_hook_signature(model.validation_step, "dataloader_iter", explicit=True)

    # ``*args`` can absorb the parameter, but it is not declared by name:
    # rejected when ``explicit=True``, accepted when ``explicit=False``.
    class VarArgsModule:
        def validation_step(self, *args):
            ...

    model = VarArgsModule()
    assert not is_param_in_hook_signature(model.validation_step, "dataloader_iter", explicit=True)
    assert is_param_in_hook_signature(model.validation_step, "dataloader_iter", explicit=False)

    # A fixed two-argument signature with neither the name nor ``*args``.
    class TwoPositionalModule:
        def validation_step(self, a, b):
            ...

    model = TwoPositionalModule()
    # Without ``*args`` or a ``min_args`` threshold, an undeclared parameter must not
    # match, even under the default non-explicit lookup.
    assert not is_param_in_hook_signature(model.validation_step, "dataloader_iter")
    # ``min_args`` counts the hook's arguments (excluding ``self``): 2 < 3, but 2 >= 2.
    assert not is_param_in_hook_signature(model.validation_step, "dataloader_iter", min_args=3)
    assert is_param_in_hook_signature(model.validation_step, "dataloader_iter", min_args=2)
|