diff --git a/pl_examples/__init__.py b/pl_examples/__init__.py index 150ac309dd..f22cb5b8e4 100644 --- a/pl_examples/__init__.py +++ b/pl_examples/__init__.py @@ -14,7 +14,6 @@ _EXAMPLES_ROOT = os.path.dirname(__file__) _PACKAGE_ROOT = os.path.dirname(_EXAMPLES_ROOT) _DATASETS_PATH = os.path.join(_PACKAGE_ROOT, 'Datasets') -_TORCHVISION_AVAILABLE = _module_available("torchvision") _TORCHVISION_MNIST_AVAILABLE = not bool(os.environ.get("PL_USE_MOCKED_MNIST", False)) _DALI_AVAILABLE = _module_available("nvidia.dali") diff --git a/pl_examples/basic_examples/autoencoder.py b/pl_examples/basic_examples/autoencoder.py index 6841b8555e..d71af97cc1 100644 --- a/pl_examples/basic_examples/autoencoder.py +++ b/pl_examples/basic_examples/autoencoder.py @@ -20,7 +20,8 @@ from torch import nn from torch.utils.data import DataLoader, random_split import pytorch_lightning as pl -from pl_examples import _DATASETS_PATH, _TORCHVISION_AVAILABLE, _TORCHVISION_MNIST_AVAILABLE, cli_lightning_logo +from pl_examples import _DATASETS_PATH, _TORCHVISION_MNIST_AVAILABLE, cli_lightning_logo +from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE if _TORCHVISION_AVAILABLE: from torchvision import transforms diff --git a/pl_examples/basic_examples/backbone_image_classifier.py b/pl_examples/basic_examples/backbone_image_classifier.py index 1c78d264a8..b727514ad7 100644 --- a/pl_examples/basic_examples/backbone_image_classifier.py +++ b/pl_examples/basic_examples/backbone_image_classifier.py @@ -19,7 +19,8 @@ from torch.nn import functional as F from torch.utils.data import DataLoader, random_split import pytorch_lightning as pl -from pl_examples import _DATASETS_PATH, _TORCHVISION_AVAILABLE, _TORCHVISION_MNIST_AVAILABLE, cli_lightning_logo +from pl_examples import _DATASETS_PATH, _TORCHVISION_MNIST_AVAILABLE, cli_lightning_logo +from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE if _TORCHVISION_AVAILABLE: from torchvision import transforms diff --git a/pl_examples/basic_examples/dali_image_classifier.py b/pl_examples/basic_examples/dali_image_classifier.py index 08bf64da25..30ad76f58c 100644 --- a/pl_examples/basic_examples/dali_image_classifier.py +++ b/pl_examples/basic_examples/dali_image_classifier.py @@ -23,13 +23,8 @@ from torch.nn import functional as F from torch.utils.data import random_split import pytorch_lightning as pl -from pl_examples import ( - _DALI_AVAILABLE, - _DATASETS_PATH, - _TORCHVISION_AVAILABLE, - _TORCHVISION_MNIST_AVAILABLE, - cli_lightning_logo, -) +from pl_examples import _DALI_AVAILABLE, _DATASETS_PATH, _TORCHVISION_MNIST_AVAILABLE, cli_lightning_logo +from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE if _TORCHVISION_AVAILABLE: from torchvision import transforms diff --git a/pl_examples/basic_examples/mnist_datamodule.py b/pl_examples/basic_examples/mnist_datamodule.py index ea64f96c05..ffb507a940 100644 --- a/pl_examples/basic_examples/mnist_datamodule.py +++ b/pl_examples/basic_examples/mnist_datamodule.py @@ -17,8 +17,9 @@ from warnings import warn from torch.utils.data import DataLoader, random_split -from pl_examples import _DATASETS_PATH, _TORCHVISION_AVAILABLE, _TORCHVISION_MNIST_AVAILABLE +from pl_examples import _DATASETS_PATH, _TORCHVISION_MNIST_AVAILABLE from pytorch_lightning import LightningDataModule +from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE if _TORCHVISION_AVAILABLE: from torchvision import transforms as transform_lib diff --git a/pl_examples/domain_templates/generative_adversarial_net.py b/pl_examples/domain_templates/generative_adversarial_net.py index 29fcf97de8..19bce65746 100644 --- a/pl_examples/domain_templates/generative_adversarial_net.py +++ b/pl_examples/domain_templates/generative_adversarial_net.py @@ -28,9 +28,10 @@ import torch.nn as nn import torch.nn.functional as F # noqa from torch.utils.data import DataLoader -from pl_examples import _TORCHVISION_AVAILABLE, _TORCHVISION_MNIST_AVAILABLE, cli_lightning_logo +from pl_examples import _TORCHVISION_MNIST_AVAILABLE, cli_lightning_logo from pytorch_lightning.core import LightningDataModule, LightningModule from pytorch_lightning.trainer import Trainer +from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE if _TORCHVISION_AVAILABLE: import torchvision