From 079fe9bc0908fffdd55a08f7321a199bceaef6f8 Mon Sep 17 00:00:00 2001 From: Kaushik B <45285388+kaushikb11@users.noreply.github.com> Date: Thu, 11 Mar 2021 16:49:48 +0530 Subject: [PATCH] Hotfix for torchvision (#6476) --- pl_examples/basic_examples/autoencoder.py | 5 +++-- pl_examples/basic_examples/backbone_image_classifier.py | 5 +++-- pl_examples/basic_examples/dali_image_classifier.py | 5 +++-- pl_examples/basic_examples/mnist_datamodule.py | 3 ++- pl_examples/domain_templates/generative_adversarial_net.py | 5 +++-- tests/helpers/datasets.py | 1 + 6 files changed, 15 insertions(+), 9 deletions(-) diff --git a/pl_examples/basic_examples/autoencoder.py b/pl_examples/basic_examples/autoencoder.py index b3188a21b7..a2010a89f4 100644 --- a/pl_examples/basic_examples/autoencoder.py +++ b/pl_examples/basic_examples/autoencoder.py @@ -22,9 +22,10 @@ from torch.utils.data import DataLoader, random_split import pytorch_lightning as pl from pl_examples import _DATASETS_PATH, _TORCHVISION_AVAILABLE, _TORCHVISION_MNIST_AVAILABLE, cli_lightning_logo -if _TORCHVISION_AVAILABLE and _TORCHVISION_MNIST_AVAILABLE: +if _TORCHVISION_AVAILABLE: from torchvision import transforms - from torchvision.datasets.mnist import MNIST +if _TORCHVISION_MNIST_AVAILABLE: + from torchvision.datasets import MNIST else: from tests.helpers.datasets import MNIST diff --git a/pl_examples/basic_examples/backbone_image_classifier.py b/pl_examples/basic_examples/backbone_image_classifier.py index 01a5dca0de..3546bee9ad 100644 --- a/pl_examples/basic_examples/backbone_image_classifier.py +++ b/pl_examples/basic_examples/backbone_image_classifier.py @@ -21,9 +21,10 @@ from torch.utils.data import DataLoader, random_split import pytorch_lightning as pl from pl_examples import _DATASETS_PATH, _TORCHVISION_AVAILABLE, _TORCHVISION_MNIST_AVAILABLE, cli_lightning_logo -if _TORCHVISION_AVAILABLE and _TORCHVISION_MNIST_AVAILABLE: +if _TORCHVISION_AVAILABLE: from torchvision import transforms - from torchvision.datasets.mnist import MNIST +if _TORCHVISION_MNIST_AVAILABLE: + from torchvision.datasets import MNIST else: from tests.helpers.datasets import MNIST diff --git a/pl_examples/basic_examples/dali_image_classifier.py b/pl_examples/basic_examples/dali_image_classifier.py index b4bf1407a9..da5b1e4fd9 100644 --- a/pl_examples/basic_examples/dali_image_classifier.py +++ b/pl_examples/basic_examples/dali_image_classifier.py @@ -31,9 +31,10 @@ from pl_examples import ( cli_lightning_logo, ) -if _TORCHVISION_AVAILABLE and _TORCHVISION_MNIST_AVAILABLE: +if _TORCHVISION_AVAILABLE: from torchvision import transforms - from torchvision.datasets.mnist import MNIST +if _TORCHVISION_MNIST_AVAILABLE: + from torchvision.datasets import MNIST else: from tests.helpers.datasets import MNIST diff --git a/pl_examples/basic_examples/mnist_datamodule.py b/pl_examples/basic_examples/mnist_datamodule.py index a50f67cdab..a6d59c64d9 100644 --- a/pl_examples/basic_examples/mnist_datamodule.py +++ b/pl_examples/basic_examples/mnist_datamodule.py @@ -20,8 +20,9 @@ from torch.utils.data import DataLoader, random_split from pl_examples import _DATASETS_PATH, _TORCHVISION_AVAILABLE, _TORCHVISION_MNIST_AVAILABLE from pytorch_lightning import LightningDataModule -if _TORCHVISION_AVAILABLE and _TORCHVISION_MNIST_AVAILABLE: +if _TORCHVISION_AVAILABLE: from torchvision import transforms as transform_lib +if _TORCHVISION_MNIST_AVAILABLE: from torchvision.datasets import MNIST else: from tests.helpers.datasets import MNIST diff --git a/pl_examples/domain_templates/generative_adversarial_net.py b/pl_examples/domain_templates/generative_adversarial_net.py index 285fba8b93..e65ede17da 100644 --- a/pl_examples/domain_templates/generative_adversarial_net.py +++ b/pl_examples/domain_templates/generative_adversarial_net.py @@ -32,9 +32,10 @@ from pl_examples import _TORCHVISION_AVAILABLE, _TORCHVISION_MNIST_AVAILABLE, cl from pytorch_lightning.core import LightningDataModule, LightningModule from pytorch_lightning.trainer import Trainer -if _TORCHVISION_AVAILABLE and _TORCHVISION_MNIST_AVAILABLE: +if _TORCHVISION_AVAILABLE: import torchvision - import torchvision.transforms as transforms + from torchvision import transforms +if _TORCHVISION_MNIST_AVAILABLE: from torchvision.datasets import MNIST else: from tests.helpers.datasets import MNIST diff --git a/tests/helpers/datasets.py b/tests/helpers/datasets.py index 5af3fbfbc4..e7bdad0f15 100644 --- a/tests/helpers/datasets.py +++ b/tests/helpers/datasets.py @@ -69,6 +69,7 @@ class MNIST(Dataset): train: bool = True, normalize: tuple = (0.1307, 0.3081), download: bool = True, + **kwargs, ): super().__init__() self.root = root