diff --git a/src/lightning/pytorch/demos/__init__.py b/src/lightning/pytorch/demos/__init__.py index fa91d7cac9..1e03e2fdfd 100644 --- a/src/lightning/pytorch/demos/__init__.py +++ b/src/lightning/pytorch/demos/__init__.py @@ -1,2 +1,15 @@ -from lightning.pytorch.demos.lstm import LightningLSTM, SequenceSampler, SimpleLSTM # noqa: F401 -from lightning.pytorch.demos.transformer import LightningTransformer, Transformer, WikiText2 # noqa: F401 +from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel, DemoModel +from lightning.pytorch.demos.lstm import LightningLSTM, SequenceSampler, SimpleLSTM +from lightning.pytorch.demos.transformer import LightningTransformer, Transformer, WikiText2 + +__all__ = [ + "LightningLSTM", + "SequenceSampler", + "SimpleLSTM", + "LightningTransformer", + "Transformer", + "WikiText2", + "BoringModel", + "BoringDataModule", + "DemoModel", +] diff --git a/tests/tests_pytorch/test_cli.py b/tests/tests_pytorch/test_cli.py index 4fc836c764..d106a05e80 100644 --- a/tests/tests_pytorch/test_cli.py +++ b/tests/tests_pytorch/test_cli.py @@ -478,8 +478,8 @@ def test_lightning_cli_print_config(): "any.py", "predict", "--seed_everything=1234", - "--model=lightning.pytorch.demos.boring_classes.BoringModel", - "--data=lightning.pytorch.demos.boring_classes.BoringDataModule", + "--model=lightning.pytorch.demos.BoringModel", + "--data=lightning.pytorch.demos.BoringDataModule", "--print_config", ] out = StringIO() @@ -492,8 +492,8 @@ def test_lightning_cli_print_config(): outval = yaml.safe_load(text) assert outval["seed_everything"] == 1234 - assert outval["model"]["class_path"] == "lightning.pytorch.demos.boring_classes.BoringModel" - assert outval["data"]["class_path"] == "lightning.pytorch.demos.boring_classes.BoringDataModule" + assert outval["model"]["class_path"] == "lightning.pytorch.demos.BoringModel" + assert outval["data"]["class_path"] == "lightning.pytorch.demos.BoringDataModule" assert outval["ckpt_path"] is None