improve importing demos (#20446)

This commit is contained in:
Jirka Borovec 2024-11-25 23:18:15 +01:00 committed by GitHub
parent 75d7357de2
commit 1f4a77c448
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 19 additions and 6 deletions

View File

@ -1,2 +1,15 @@
from lightning.pytorch.demos.lstm import LightningLSTM, SequenceSampler, SimpleLSTM # noqa: F401
from lightning.pytorch.demos.transformer import LightningTransformer, Transformer, WikiText2 # noqa: F401
from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel, DemoModel
from lightning.pytorch.demos.lstm import LightningLSTM, SequenceSampler, SimpleLSTM
from lightning.pytorch.demos.transformer import LightningTransformer, Transformer, WikiText2
__all__ = [
"LightningLSTM",
"SequenceSampler",
"SimpleLSTM",
"LightningTransformer",
"Transformer",
"WikiText2",
"BoringModel",
"BoringDataModule",
"DemoModel",
]

View File

@ -478,8 +478,8 @@ def test_lightning_cli_print_config():
"any.py",
"predict",
"--seed_everything=1234",
"--model=lightning.pytorch.demos.boring_classes.BoringModel",
"--data=lightning.pytorch.demos.boring_classes.BoringDataModule",
"--model=lightning.pytorch.demos.BoringModel",
"--data=lightning.pytorch.demos.BoringDataModule",
"--print_config",
]
out = StringIO()
@ -492,8 +492,8 @@ def test_lightning_cli_print_config():
outval = yaml.safe_load(text)
assert outval["seed_everything"] == 1234
assert outval["model"]["class_path"] == "lightning.pytorch.demos.boring_classes.BoringModel"
assert outval["data"]["class_path"] == "lightning.pytorch.demos.boring_classes.BoringDataModule"
assert outval["model"]["class_path"] == "lightning.pytorch.demos.BoringModel"
assert outval["data"]["class_path"] == "lightning.pytorch.demos.BoringDataModule"
assert outval["ckpt_path"] is None