2022-06-30 20:43:04 +00:00
|
|
|
import os
|
|
|
|
from numbers import Rational
|
|
|
|
|
2023-02-01 11:07:00 +00:00
|
|
|
from lightning.app import LightningApp, LightningFlow
|
|
|
|
from lightning.app.testing.helpers import _RunIf
|
|
|
|
from lightning.app.utilities.imports import _is_pytorch_lightning_available
|
|
|
|
from lightning.app.utilities.introspection import Scanner
|
2022-06-30 20:43:04 +00:00
|
|
|
|
|
|
|
if _is_pytorch_lightning_available():
|
|
|
|
from pytorch_lightning import Trainer
|
2022-07-23 12:07:29 +00:00
|
|
|
from pytorch_lightning.cli import LightningCLI
|
2022-06-30 20:43:04 +00:00
|
|
|
|
|
|
|
from tests_app import _PROJECT_ROOT
|
|
|
|
|
|
|
|
|
|
|
|
def test_introspection():
    """Check that ``Scanner.has_class`` reports which classes each example script references.

    Each example script is paired with a mapping ``{class: should_be_found}``.
    Entries are checked in insertion order against a fresh ``Scanner`` for that script.
    """
    expectations = {
        "tests/tests_app/core/scripts/example_1.py": {Rational: True, LightningApp: False},
        "tests/tests_app/core/scripts/example_2.py": {LightningApp: True, LightningFlow: False},
    }

    for relative_path, expected in expectations.items():
        # os.path.join already returns a str, so no extra conversion is needed.
        scanner = Scanner(os.path.join(_PROJECT_ROOT, relative_path))
        for cls, should_be_found in expected.items():
            assert bool(scanner.has_class(cls)) == should_be_found
|
|
|
|
|
|
|
|
|
2022-10-28 13:57:35 +00:00
|
|
|
@_RunIf(pl=True)
def test_introspection_lightning():
    """Check that ``Scanner.has_class`` distinguishes PyTorch Lightning entry points per script.

    ``lightning_cli.py`` should reference ``LightningCLI`` but not ``Trainer``.
    ``lightning_trainer.py`` should reference ``Trainer`` but not ``LightningCLI``.
    """
    expectations = {
        "tests/tests_app/core/scripts/lightning_cli.py": {Trainer: False, LightningCLI: True},
        "tests/tests_app/core/scripts/lightning_trainer.py": {Trainer: True, LightningCLI: False},
    }

    for relative_path, expected in expectations.items():
        scanner = Scanner(os.path.join(_PROJECT_ROOT, relative_path))
        for cls, should_be_found in expected.items():
            assert bool(scanner.has_class(cls)) == should_be_found
|
|
|
|
|
|
|
|
|
2022-10-28 13:57:35 +00:00
|
|
|
@_RunIf(pl=True)
def test_introspection_lightning_overrides():
    """Check that ``Scanner.scan`` finds every subclassed PyTorch Lightning primitive in a script.

    ``scan()`` is compared as a set of keys against the primitive class names
    each script is expected to subclass.
    """
    expected_by_script = {
        "tests/tests_app/core/scripts/lightning_cli.py": {"LightningDataModule", "LightningModule"},
        "tests/tests_app/core/scripts/lightning_overrides.py": {
            "Accelerator",
            "Profiler",
            "Callback",
            "LightningDataModule",
            "Fabric",
            "Logger",
            "LightningModule",
            "Metric",
            "PrecisionPlugin",
            "Trainer",
        },
    }

    for relative_path, expected_names in expected_by_script.items():
        found = Scanner(os.path.join(_PROJECT_ROOT, relative_path)).scan()
        assert set(found) == expected_names
|