Add MNIST dataset & drop torchvision dep. from tests (#986)

* added custom mnist without torchvision dep

* move files so it does not conflict with mnist gitignore

* mock torchvision for tests

* fix line too long

* fix line too long

* fix "module level import not at top of file" warning

* move mock imports to __init__.py

* simplify MNIST a lot and download directly the .pt files

* further simplify and clean up mnist

* revert import overrides

* make as before

* drop PIL requirement

* move mnist.py to datasets subfolder

* use logging instead of print

* choose same name as in torchvision

* remove torchvision and pillow also from yml file

* refactor if train

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* capitalized class attr

* moved mnist to models

* re-added datasets ignore

* better name for file variable

* Update mnist.py

* move dataset classes to datasets.py

* new line

* update

* update

* fix automerge

* move to base folder

* adapt testingmnist to new mnist base class

* remove temporal fix

* fix datatype

* remove old testingmnist

* readable

* fix import

* fix whitespace

* docstring

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* Update tests/base/datasets.py

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* changelog

* added types

* Update CHANGELOG.md

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* exist->isfile

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* index -> idx

* temporary fix for trains error

* better changelog message

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
This commit is contained in:
Adrian Wälchli 2020-03-31 00:25:37 +02:00 committed by GitHub
parent 18d055a390
commit b7de42f70d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 115 additions and 40 deletions

1
.gitignore vendored
View File

@ -17,7 +17,6 @@ docs/source/pl_examples*.rst
docs/source/pytorch_lightning*.rst
docs/source/tests*.rst
docs/source/*.md
tests/tests/
# Byte-compiled / optimized / DLL files
__pycache__/

View File

@ -30,6 +30,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Removed
- Removed duplicated module `pytorch_lightning.utilities.arg_parse` for loading CLI arguments ([#1167](https://github.com/PyTorchLightning/pytorch-lightning/issues/1167))
- Dropped `torchvision` dependency in tests and added own MNIST dataset class instead ([#986](https://github.com/PyTorchLightning/pytorch-lightning/issues/986))
### Fixed

View File

@ -15,7 +15,6 @@ dependencies:
- future>=0.17.1
# For dev and testing
- torchvision>=0.4.0
- tox
- coverage
- codecov
@ -26,7 +25,6 @@ dependencies:
- autopep8
- check-manifest
- twine==1.13.0
- pillow<7.0.0
- pip:
- test-tube>=0.7.5

107
tests/base/datasets.py Normal file
View File

@ -0,0 +1,107 @@
import logging
import os
import urllib.request
from typing import Tuple
import torch
from torch import Tensor
from torch.utils.data import Dataset
class MNIST(Dataset):
    """
    Customized `MNIST <http://yann.lecun.com/exdb/mnist/>`_ dataset for testing PyTorch Lightning
    without the torchvision dependency.

    Part of the code was copied from
    https://github.com/pytorch/vision/blob/build/v0.5.0/torchvision/datasets/mnist.py

    Args:
        root: Root directory of dataset where ``MNIST/processed/training.pt``
            and ``MNIST/processed/test.pt`` exist.
        train: If ``True``, creates dataset from ``training.pt``,
            otherwise from ``test.pt``.
        normalize: mean and std deviation of the MNIST dataset, applied to every image
            returned by ``__getitem__``. Pass ``None`` to skip normalization.
        download: If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.

    Raises:
        RuntimeError: If, after the optional download, ``training.pt`` and ``test.pt``
            are not both present in ``processed_folder``.
    """

    RESOURCES = (
        "https://pl-public-data.s3.amazonaws.com/MNIST/processed/training.pt",
        "https://pl-public-data.s3.amazonaws.com/MNIST/processed/test.pt",
    )

    TRAIN_FILE_NAME = 'training.pt'
    TEST_FILE_NAME = 'test.pt'

    def __init__(self, root: str, train: bool = True, normalize: tuple = (0.5, 1.0), download: bool = False):
        super().__init__()
        self.root = root
        self.train = train  # training set or test set
        self.normalize = normalize

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found.')

        data_file = self.TRAIN_FILE_NAME if self.train else self.TEST_FILE_NAME
        # NOTE: ``torch.load`` unpickles the file, so ``root`` must only contain trusted data.
        self.data, self.targets = torch.load(os.path.join(self.processed_folder, data_file))

    def __getitem__(self, idx: int) -> Tuple[Tensor, int]:
        """Return the ``(image, label)`` pair stored at position ``idx``.

        The image is converted to float and gets one extra leading dimension
        (presumably the channel axis); the label is converted to a plain ``int``.
        """
        img = self.data[idx].float().unsqueeze(0)
        target = int(self.targets[idx])

        if self.normalize is not None:
            img = normalize_tensor(img, mean=self.normalize[0], std=self.normalize[1])

        return img, target

    def __len__(self) -> int:
        """Number of samples held in ``self.data``."""
        return len(self.data)

    @property
    def processed_folder(self) -> str:
        """Directory ``<root>/MNIST/processed`` in which the ``.pt`` files are stored."""
        return os.path.join(self.root, 'MNIST', 'processed')

    def _check_exists(self) -> bool:
        """Return ``True`` only if both the training and the test file are present."""
        train_file = os.path.join(self.processed_folder, self.TRAIN_FILE_NAME)
        test_file = os.path.join(self.processed_folder, self.TEST_FILE_NAME)
        return os.path.isfile(train_file) and os.path.isfile(test_file)

    def download(self) -> None:
        """Download the MNIST data if it doesn't exist in processed_folder already.

        Each resource is first written to a process-specific ``.part`` file and only
        moved to its final name once the transfer has finished. An interrupted or
        failed download therefore does not leave a truncated ``.pt`` file behind
        that ``_check_exists`` would accept on the next run.
        """
        if self._check_exists():
            return

        os.makedirs(self.processed_folder, exist_ok=True)

        for url in self.RESOURCES:
            logging.info(f'Downloading {url}')
            fpath = os.path.join(self.processed_folder, os.path.basename(url))
            # the pid keeps parallel processes from writing into the same temporary file
            tmp_path = f'{fpath}.{os.getpid()}.part'
            try:
                urllib.request.urlretrieve(url, tmp_path)
                os.replace(tmp_path, fpath)
            finally:
                # only still present when the transfer or the rename failed
                if os.path.exists(tmp_path):
                    os.remove(tmp_path)
def normalize_tensor(tensor: Tensor, mean: float = 0.0, std: float = 1.0) -> Tensor:
    """Return a normalized copy of ``tensor``, i.e. ``(tensor - mean) / std``.

    The input is cloned before any arithmetic, so the caller's tensor is not modified.

    Args:
        tensor: values to normalize.
        mean: offset subtracted from every element.
        std: divisor applied after the subtraction.

    Returns:
        The shifted and scaled clone of ``tensor``.
    """
    # wrap both scalars as tensors that share dtype and device with the input
    like_input = dict(dtype=tensor.dtype, device=tensor.device)
    offset = torch.as_tensor(mean, **like_input)
    scale = torch.as_tensor(std, **like_input)

    result = tensor.clone()
    result.sub_(offset)
    result.div_(scale)
    return result
class TestingMNIST(MNIST):
    """Subset of ``MNIST`` that keeps only the first ``num_samples`` items."""

    def __init__(self, root, train=True, normalize=(0.5, 1.0), download=False, num_samples=8000):
        super().__init__(root, train=train, normalize=normalize, download=download)
        # cut images and labels down to the same leading slice
        head = slice(None, num_samples)
        self.data, self.targets = self.data[head], self.targets[head]

View File

@ -1,9 +1,9 @@
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import pytorch_lightning as pl
from tests.base.datasets import MNIST
# from test_models import assert_ok_test_acc, load_model, \

View File

@ -7,8 +7,8 @@ import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from tests.base.datasets import TestingMNIST
try:
from test_tube import HyperOptArgumentParser
@ -18,29 +18,6 @@ except ImportError:
from pytorch_lightning.core.lightning import LightningModule
# TODO: remove after getting own MNIST
# TEMPORAL FIX, https://github.com/pytorch/vision/issues/1938
import urllib.request
opener = urllib.request.build_opener()
opener.addheaders = [('User-agent', 'Mozilla/5.0')]
urllib.request.install_opener(opener)
class TestingMNIST(MNIST):
    """Subset of the ``MNIST`` base dataset that keeps only the first ``num_samples`` items."""

    def __init__(
        self,
        root,
        train=True,
        transform=None,
        target_transform=None,
        download=False,
        num_samples=8000,
    ):
        forwarded = dict(train=train, transform=transform, target_transform=target_transform, download=download)
        super().__init__(root, **forwarded)
        # cut images and labels down to the same leading slice
        head = slice(None, num_samples)
        self.data, self.targets = self.data[head], self.targets[head]
class DictHparamsModel(LightningModule):
@ -61,8 +38,7 @@ class DictHparamsModel(LightningModule):
return torch.optim.Adam(self.parameters(), lr=0.02)
def train_dataloader(self):
return DataLoader(TestingMNIST(os.getcwd(), train=True, download=True,
transform=transforms.ToTensor()), batch_size=32)
return DataLoader(TestingMNIST(os.getcwd(), train=True, download=True), batch_size=32)
class TestModelBase(LightningModule):
@ -178,17 +154,13 @@ class TestModelBase(LightningModule):
return [optimizer], [scheduler]
def prepare_data(self):
transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.5,), (1.0,))])
_ = TestingMNIST(root=self.hparams.data_root, train=True,
transform=transform, download=True, num_samples=2000)
download=True, num_samples=2000)
def _dataloader(self, train):
# init data generators
transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.5,), (1.0,))])
dataset = TestingMNIST(root=self.hparams.data_root, train=train,
transform=transform, download=False, num_samples=2000)
download=False, num_samples=2000)
# when using multi-node we need to add the datasampler
batch_size = self.hparams.batch_size

View File

View File

@ -2,7 +2,6 @@
-r ../requirements-extra.txt
# extended list of dependencies for development and to run lint and tests
torchvision>=0.4.0, < 0.5 # the 0.5. has some issues with torch JIT
tox
coverage
codecov
@ -11,5 +10,4 @@ pytest-cov
pytest-flake8
flake8
check-manifest
twine==1.13.0
pillow<7.0.0
twine==1.13.0