32 lines
1.1 KiB
Python
32 lines
1.1 KiB
Python
import torch
|
|
from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature
|
|
|
|
|
|
def test_param_in_hook_signature():
    """Exercise ``is_param_in_hook_signature`` against several hook signatures.

    Covers four shapes of ``validation_step``: a parameter declared by name,
    the same declaration behind a ``torch.no_grad()`` decorator, a ``*args``
    catch-all (accepted only when ``explicit=False``), and a two-parameter
    signature checked via ``min_args``.
    """

    # Case 1: the parameter is spelled out in the signature.
    class ExplicitParamModule:
        def validation_step(self, dataloader_iter): ...

    explicit_hook = ExplicitParamModule().validation_step
    assert is_param_in_hook_signature(explicit_hook, "dataloader_iter", explicit=True)

    # Case 2: identical signature, but wrapped by a decorator; the assertion
    # expects the check to still find the explicitly named parameter.
    class DecoratedModule:
        @torch.no_grad()
        def validation_step(self, dataloader_iter): ...

    decorated_hook = DecoratedModule().validation_step
    assert is_param_in_hook_signature(decorated_hook, "dataloader_iter", explicit=True)

    # Case 3: only ``*args`` -- rejected when an explicit declaration is
    # required, accepted otherwise.
    class VarArgsModule:
        def validation_step(self, *args): ...

    varargs_hook = VarArgsModule().validation_step
    assert not is_param_in_hook_signature(varargs_hook, "dataloader_iter", explicit=True)
    assert is_param_in_hook_signature(varargs_hook, "dataloader_iter", explicit=False)

    # Case 4: the name is absent; with two declared params (a, b) the check
    # passes for min_args=2 but not for min_args=3.
    class TwoArgModule:
        def validation_step(self, a, b): ...

    two_arg_hook = TwoArgModule().validation_step
    assert not is_param_in_hook_signature(two_arg_hook, "dataloader_iter", min_args=3)
    assert is_param_in_hook_signature(two_arg_hook, "dataloader_iter", min_args=2)
|