Hotfix for torchvision (#6476)
This commit is contained in:
parent
2ecda5df52
commit
079fe9bc09
|
@ -22,9 +22,10 @@ from torch.utils.data import DataLoader, random_split
|
|||
import pytorch_lightning as pl
|
||||
from pl_examples import _DATASETS_PATH, _TORCHVISION_AVAILABLE, _TORCHVISION_MNIST_AVAILABLE, cli_lightning_logo
|
||||
|
||||
if _TORCHVISION_AVAILABLE and _TORCHVISION_MNIST_AVAILABLE:
|
||||
if _TORCHVISION_AVAILABLE:
|
||||
from torchvision import transforms
|
||||
from torchvision.datasets.mnist import MNIST
|
||||
if _TORCHVISION_MNIST_AVAILABLE:
|
||||
from torchvision.datasets import MNIST
|
||||
else:
|
||||
from tests.helpers.datasets import MNIST
|
||||
|
||||
|
|
|
@ -21,9 +21,10 @@ from torch.utils.data import DataLoader, random_split
|
|||
import pytorch_lightning as pl
|
||||
from pl_examples import _DATASETS_PATH, _TORCHVISION_AVAILABLE, _TORCHVISION_MNIST_AVAILABLE, cli_lightning_logo
|
||||
|
||||
if _TORCHVISION_AVAILABLE and _TORCHVISION_MNIST_AVAILABLE:
|
||||
if _TORCHVISION_AVAILABLE:
|
||||
from torchvision import transforms
|
||||
from torchvision.datasets.mnist import MNIST
|
||||
if _TORCHVISION_MNIST_AVAILABLE:
|
||||
from torchvision.datasets import MNIST
|
||||
else:
|
||||
from tests.helpers.datasets import MNIST
|
||||
|
||||
|
|
|
@ -31,9 +31,10 @@ from pl_examples import (
|
|||
cli_lightning_logo,
|
||||
)
|
||||
|
||||
if _TORCHVISION_AVAILABLE and _TORCHVISION_MNIST_AVAILABLE:
|
||||
if _TORCHVISION_AVAILABLE:
|
||||
from torchvision import transforms
|
||||
from torchvision.datasets.mnist import MNIST
|
||||
if _TORCHVISION_MNIST_AVAILABLE:
|
||||
from torchvision.datasets import MNIST
|
||||
else:
|
||||
from tests.helpers.datasets import MNIST
|
||||
|
||||
|
|
|
@ -20,8 +20,9 @@ from torch.utils.data import DataLoader, random_split
|
|||
from pl_examples import _DATASETS_PATH, _TORCHVISION_AVAILABLE, _TORCHVISION_MNIST_AVAILABLE
|
||||
from pytorch_lightning import LightningDataModule
|
||||
|
||||
if _TORCHVISION_AVAILABLE and _TORCHVISION_MNIST_AVAILABLE:
|
||||
if _TORCHVISION_AVAILABLE:
|
||||
from torchvision import transforms as transform_lib
|
||||
if _TORCHVISION_MNIST_AVAILABLE:
|
||||
from torchvision.datasets import MNIST
|
||||
else:
|
||||
from tests.helpers.datasets import MNIST
|
||||
|
|
|
@ -32,9 +32,10 @@ from pl_examples import _TORCHVISION_AVAILABLE, _TORCHVISION_MNIST_AVAILABLE, cl
|
|||
from pytorch_lightning.core import LightningDataModule, LightningModule
|
||||
from pytorch_lightning.trainer import Trainer
|
||||
|
||||
if _TORCHVISION_AVAILABLE and _TORCHVISION_MNIST_AVAILABLE:
|
||||
if _TORCHVISION_AVAILABLE:
|
||||
import torchvision
|
||||
import torchvision.transforms as transforms
|
||||
from torchvision import transforms
|
||||
if _TORCHVISION_MNIST_AVAILABLE:
|
||||
from torchvision.datasets import MNIST
|
||||
else:
|
||||
from tests.helpers.datasets import MNIST
|
||||
|
|
|
@ -69,6 +69,7 @@ class MNIST(Dataset):
|
|||
train: bool = True,
|
||||
normalize: tuple = (0.1307, 0.3081),
|
||||
download: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.root = root
|
||||
|
|
Loading…
Reference in New Issue