improve importing demos (#20446)
This commit is contained in:
parent
75d7357de2
commit
1f4a77c448
|
@ -1,2 +1,15 @@
|
|||
from lightning.pytorch.demos.lstm import LightningLSTM, SequenceSampler, SimpleLSTM # noqa: F401
|
||||
from lightning.pytorch.demos.transformer import LightningTransformer, Transformer, WikiText2 # noqa: F401
|
||||
from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel, DemoModel
|
||||
from lightning.pytorch.demos.lstm import LightningLSTM, SequenceSampler, SimpleLSTM
|
||||
from lightning.pytorch.demos.transformer import LightningTransformer, Transformer, WikiText2
|
||||
|
||||
__all__ = [
|
||||
"LightningLSTM",
|
||||
"SequenceSampler",
|
||||
"SimpleLSTM",
|
||||
"LightningTransformer",
|
||||
"Transformer",
|
||||
"WikiText2",
|
||||
"BoringModel",
|
||||
"BoringDataModule",
|
||||
"DemoModel",
|
||||
]
|
||||
|
|
|
@ -478,8 +478,8 @@ def test_lightning_cli_print_config():
|
|||
"any.py",
|
||||
"predict",
|
||||
"--seed_everything=1234",
|
||||
"--model=lightning.pytorch.demos.boring_classes.BoringModel",
|
||||
"--data=lightning.pytorch.demos.boring_classes.BoringDataModule",
|
||||
"--model=lightning.pytorch.demos.BoringModel",
|
||||
"--data=lightning.pytorch.demos.BoringDataModule",
|
||||
"--print_config",
|
||||
]
|
||||
out = StringIO()
|
||||
|
@ -492,8 +492,8 @@ def test_lightning_cli_print_config():
|
|||
|
||||
outval = yaml.safe_load(text)
|
||||
assert outval["seed_everything"] == 1234
|
||||
assert outval["model"]["class_path"] == "lightning.pytorch.demos.boring_classes.BoringModel"
|
||||
assert outval["data"]["class_path"] == "lightning.pytorch.demos.boring_classes.BoringDataModule"
|
||||
assert outval["model"]["class_path"] == "lightning.pytorch.demos.BoringModel"
|
||||
assert outval["data"]["class_path"] == "lightning.pytorch.demos.BoringDataModule"
|
||||
assert outval["ckpt_path"] is None
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue