lightning/tests/tests_pytorch/utilities/test_signature_utils.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

36 lines
1.2 KiB
Python
Raw Normal View History

import torch
from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature
def test_param_in_hook_signature():
    """Check ``is_param_in_hook_signature`` against four ``validation_step`` signatures.

    Each case defines a throwaway class with a differently shaped
    ``validation_step`` and asserts the helper's result for the
    ``"dataloader_iter"`` parameter name.
    """
    param = "dataloader_iter"

    # Case 1: the parameter is spelled out in the hook signature.
    class NamedParamModule:
        def validation_step(self, dataloader_iter):
            ...

    hook = NamedParamModule().validation_step
    assert is_param_in_hook_signature(hook, param, explicit=True)

    # Case 2: same signature, but the hook is wrapped by ``torch.no_grad()``.
    class NoGradModule:
        @torch.no_grad()
        def validation_step(self, dataloader_iter):
            ...

    hook = NoGradModule().validation_step
    assert is_param_in_hook_signature(hook, param, explicit=True)

    # Case 3: only ``*args`` is declared -- rejected when ``explicit=True``,
    # accepted when ``explicit=False``.
    class VarArgsModule:
        def validation_step(self, *args):
            ...

    varargs_model = VarArgsModule()
    assert not is_param_in_hook_signature(varargs_model.validation_step, param, explicit=True)
    assert is_param_in_hook_signature(varargs_model.validation_step, param, explicit=False)

    # Case 4: two parameters, neither named ``dataloader_iter`` -- the result
    # depends on the ``min_args`` threshold (rejected at 3, accepted at 2).
    class TwoArgModule:
        def validation_step(self, a, b):
            ...

    two_arg_model = TwoArgModule()
    assert not is_param_in_hook_signature(two_arg_model.validation_step, param, min_args=3)
    assert is_param_in_hook_signature(two_arg_model.validation_step, param, min_args=2)