lightning/tests/tests_pytorch/utilities/test_signature_utils.py

32 lines
1.1 KiB
Python

import torch
from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature
def test_param_in_hook_signature():
    """Check ``is_param_in_hook_signature`` against several ``validation_step`` signatures.

    Scenarios, in order:

    1. the hook names ``dataloader_iter`` directly;
    2. the same hook wrapped by ``torch.no_grad()``;
    3. the hook only accepts ``*args``;
    4. the hook takes two positional parameters and is checked through ``min_args``.
    """

    # 1. Parameter declared explicitly on a plain method.
    class ExplicitParamModule:
        def validation_step(self, dataloader_iter): ...

    explicit_hook = ExplicitParamModule().validation_step
    assert is_param_in_hook_signature(explicit_hook, "dataloader_iter", explicit=True)

    # 2. Same signature behind a decorator; the parameter is still expected to be found.
    class DecoratedParamModule:
        @torch.no_grad()
        def validation_step(self, dataloader_iter): ...

    decorated_hook = DecoratedParamModule().validation_step
    assert is_param_in_hook_signature(decorated_hook, "dataloader_iter", explicit=True)

    # 3. Variadic positional arguments: expected to match only when ``explicit`` is False.
    class VarArgsModule:
        def validation_step(self, *args): ...

    varargs_hook = VarArgsModule().validation_step
    assert not is_param_in_hook_signature(varargs_hook, "dataloader_iter", explicit=True)
    assert is_param_in_hook_signature(varargs_hook, "dataloader_iter", explicit=False)

    # 4. Parameter not named; two declared parameters are expected to satisfy
    #    ``min_args=2`` but not ``min_args=3``.
    class TwoArgModule:
        def validation_step(self, a, b): ...

    two_arg_hook = TwoArgModule().validation_step
    assert not is_param_in_hook_signature(two_arg_hook, "dataloader_iter", min_args=3)
    assert is_param_in_hook_signature(two_arg_hook, "dataloader_iter", min_args=2)