Add MNIST dataset & drop torchvision dep. from tests (#986)
* added custom mnist without torchvision dep * move files so it does not conflict with mnist gitignore * mock torchvision for tests * fix line too long * fix line too long * fix "module level import not at top of file" warning * move mock imports to __init__.py * simplify MNIST a lot and download directly the .pt files * further simplify and clean up mnist * revert import overrides * make as before * drop PIL requirement * move mnist.py to datasets subfolder * use logging instead of print * choose same name as in torchvision * remove torchvision and pillow also from yml file * refactor if train Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * capitalized class attr * moved mnist to models * re-added datsets ignore * better name for file variable * Update mnist.py * move dataset classes to datasets.py * new line * update * update * fix automerge * move to base folder * adapt testingmnist to new mnist base class * remove temporal fix * fix datatype * remove old testingmnist * readable * fix import * fix whitespace * docstring Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * Update tests/base/datasets.py Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * changelog * added types * Update CHANGELOG.md Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * exist->isfile Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * index -> idx * temporary fix for trains error * better changelog message Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
This commit is contained in:
parent
18d055a390
commit
b7de42f70d
|
@ -17,7 +17,6 @@ docs/source/pl_examples*.rst
|
|||
docs/source/pytorch_lightning*.rst
|
||||
docs/source/tests*.rst
|
||||
docs/source/*.md
|
||||
tests/tests/
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
|
|
|
@ -30,6 +30,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
### Removed
|
||||
|
||||
- Removed duplicated module `pytorch_lightning.utilities.arg_parse` for loading CLI arguments ([#1167](https://github.com/PyTorchLightning/pytorch-lightning/issues/1167))
|
||||
- Dropped `torchvision` dependency in tests and added own MNIST dataset class instead ([#986](https://github.com/PyTorchLightning/pytorch-lightning/issues/986))
|
||||
|
||||
### Fixed
|
||||
|
||||
|
|
|
@ -15,7 +15,6 @@ dependencies:
|
|||
- future>=0.17.1
|
||||
|
||||
# For dev and testing
|
||||
- torchvision>=0.4.0
|
||||
- tox
|
||||
- coverage
|
||||
- codecov
|
||||
|
@ -26,7 +25,6 @@ dependencies:
|
|||
- autopep8
|
||||
- check-manifest
|
||||
- twine==1.13.0
|
||||
- pillow<7.0.0
|
||||
|
||||
- pip:
|
||||
- test-tube>=0.7.5
|
||||
|
|
|
@ -0,0 +1,107 @@
|
|||
import logging
|
||||
import os
|
||||
import urllib.request
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
|
||||
class MNIST(Dataset):
    """
    Customized `MNIST <http://yann.lecun.com/exdb/mnist/>`_ dataset for testing PyTorch Lightning
    without the torchvision dependency.

    Part of the code was copied from
    https://github.com/pytorch/vision/blob/build/v0.5.0/torchvision/datasets/mnist.py

    Args:
        root: Root directory of dataset where ``MNIST/processed/training.pt``
            and ``MNIST/processed/test.pt`` exist.
        train: If ``True``, creates dataset from ``training.pt``,
            otherwise from ``test.pt``.
        normalize: ``(mean, std)`` pair applied to every image returned by
            ``__getitem__``. Pass ``None`` to get the images un-normalized.
        download: If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.

    Raises:
        RuntimeError: If ``training.pt`` or ``test.pt`` is missing from the
            processed folder (after the optional download).
    """

    RESOURCES = (
        "https://pl-public-data.s3.amazonaws.com/MNIST/processed/training.pt",
        "https://pl-public-data.s3.amazonaws.com/MNIST/processed/test.pt",
    )

    TRAIN_FILE_NAME = 'training.pt'
    TEST_FILE_NAME = 'test.pt'

    def __init__(self, root: str, train: bool = True, normalize: tuple = (0.5, 1.0), download: bool = False):
        super().__init__()
        self.root = root
        self.train = train  # training set or test set
        self.normalize = normalize

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found.')

        data_file = self.TRAIN_FILE_NAME if self.train else self.TEST_FILE_NAME
        # NOTE(security): torch.load is pickle-based and the files come from the
        # network; only point RESOURCES at a trusted host.
        self.data, self.targets = torch.load(os.path.join(self.processed_folder, data_file))

    def __getitem__(self, idx: int) -> Tuple[Tensor, int]:
        """Return the ``idx``-th ``(image, label)`` pair.

        The image is cast to float and gets a new leading dimension; it is
        normalized only when ``self.normalize`` is not ``None``.
        """
        img = self.data[idx].float().unsqueeze(0)
        target = int(self.targets[idx])

        if self.normalize is not None:
            img = normalize_tensor(img, mean=self.normalize[0], std=self.normalize[1])

        return img, target

    def __len__(self) -> int:
        """Return the number of samples held in ``self.data``."""
        return len(self.data)

    @property
    def processed_folder(self) -> str:
        """Directory ``<root>/MNIST/processed`` that holds the ``.pt`` files."""
        return os.path.join(self.root, 'MNIST', 'processed')

    def _check_exists(self) -> bool:
        """Return ``True`` only if both the training and the test file are present."""
        train_file = os.path.join(self.processed_folder, self.TRAIN_FILE_NAME)
        test_file = os.path.join(self.processed_folder, self.TEST_FILE_NAME)
        return os.path.isfile(train_file) and os.path.isfile(test_file)

    def download(self) -> None:
        """Download the MNIST data if it doesn't exist in processed_folder already.

        Each resource is first written to a ``.part`` file and moved into place
        only once the transfer has finished, so an interrupted download can never
        leave a truncated file that ``_check_exists`` would accept as valid.
        """
        if self._check_exists():
            return

        os.makedirs(self.processed_folder, exist_ok=True)

        for url in self.RESOURCES:
            fpath = os.path.join(self.processed_folder, os.path.basename(url))
            if os.path.isfile(fpath):
                # Files are only ever published complete, so keep what is there.
                continue

            logging.info(f'Downloading {url}')
            partial_path = fpath + '.part'
            try:
                urllib.request.urlretrieve(url, partial_path)
                os.replace(partial_path, fpath)
            finally:
                # Only still present if the transfer or the rename failed.
                if os.path.isfile(partial_path):
                    os.remove(partial_path)
|
||||
|
||||
|
||||
def normalize_tensor(tensor: Tensor, mean: float = 0.0, std: float = 1.0) -> Tensor:
    """Return a copy of ``tensor`` with ``mean`` subtracted and then divided by ``std``.

    The input is cloned first, so the caller's tensor is left untouched.
    ``mean`` and ``std`` are converted to the dtype and device of the input
    before the arithmetic is applied in place on the copy.
    """
    result = tensor.clone()
    like_input = dict(dtype=result.dtype, device=result.device)
    shift = torch.as_tensor(mean, **like_input)
    scale = torch.as_tensor(std, **like_input)
    result.sub_(shift)
    result.div_(scale)
    return result
|
||||
|
||||
|
||||
class TestingMNIST(MNIST):
    """``MNIST`` restricted to its first ``num_samples`` samples.

    All arguments except ``num_samples`` are forwarded unchanged to ``MNIST``.
    """

    def __init__(self, root, train=True, normalize=(0.5, 1.0), download=False, num_samples=8000):
        super().__init__(root, train=train, normalize=normalize, download=download)
        # take just a subset of MNIST dataset
        head = slice(num_samples)
        self.data, self.targets = self.data[head], self.targets[head]
|
|
@ -1,9 +1,9 @@
|
|||
import torch
|
||||
from torch.nn import functional as F
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision.datasets import MNIST
|
||||
|
||||
import pytorch_lightning as pl
|
||||
from tests.base.datasets import MNIST
|
||||
|
||||
|
||||
# from test_models import assert_ok_test_acc, load_model, \
|
||||
|
|
|
@ -7,8 +7,8 @@ import torch.nn as nn
|
|||
import torch.nn.functional as F
|
||||
from torch import optim
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision import transforms
|
||||
from torchvision.datasets import MNIST
|
||||
|
||||
from tests.base.datasets import TestingMNIST
|
||||
|
||||
try:
|
||||
from test_tube import HyperOptArgumentParser
|
||||
|
@ -18,29 +18,6 @@ except ImportError:
|
|||
|
||||
from pytorch_lightning.core.lightning import LightningModule
|
||||
|
||||
# TODO: remove after getting own MNIST
|
||||
# TEMPORAL FIX, https://github.com/pytorch/vision/issues/1938
|
||||
import urllib.request
|
||||
opener = urllib.request.build_opener()
|
||||
opener.addheaders = [('User-agent', 'Mozilla/5.0')]
|
||||
urllib.request.install_opener(opener)
|
||||
|
||||
|
||||
class TestingMNIST(MNIST):
|
||||
|
||||
def __init__(self, root, train=True, transform=None, target_transform=None,
|
||||
download=False, num_samples=8000):
|
||||
super().__init__(
|
||||
root,
|
||||
train=train,
|
||||
transform=transform,
|
||||
target_transform=target_transform,
|
||||
download=download
|
||||
)
|
||||
# take just a subset of MNIST dataset
|
||||
self.data = self.data[:num_samples]
|
||||
self.targets = self.targets[:num_samples]
|
||||
|
||||
|
||||
class DictHparamsModel(LightningModule):
|
||||
|
||||
|
@ -61,8 +38,7 @@ class DictHparamsModel(LightningModule):
|
|||
return torch.optim.Adam(self.parameters(), lr=0.02)
|
||||
|
||||
def train_dataloader(self):
|
||||
return DataLoader(TestingMNIST(os.getcwd(), train=True, download=True,
|
||||
transform=transforms.ToTensor()), batch_size=32)
|
||||
return DataLoader(TestingMNIST(os.getcwd(), train=True, download=True), batch_size=32)
|
||||
|
||||
|
||||
class TestModelBase(LightningModule):
|
||||
|
@ -178,17 +154,13 @@ class TestModelBase(LightningModule):
|
|||
return [optimizer], [scheduler]
|
||||
|
||||
def prepare_data(self):
|
||||
transform = transforms.Compose([transforms.ToTensor(),
|
||||
transforms.Normalize((0.5,), (1.0,))])
|
||||
_ = TestingMNIST(root=self.hparams.data_root, train=True,
|
||||
transform=transform, download=True, num_samples=2000)
|
||||
download=True, num_samples=2000)
|
||||
|
||||
def _dataloader(self, train):
|
||||
# init data generators
|
||||
transform = transforms.Compose([transforms.ToTensor(),
|
||||
transforms.Normalize((0.5,), (1.0,))])
|
||||
dataset = TestingMNIST(root=self.hparams.data_root, train=train,
|
||||
transform=transform, download=False, num_samples=2000)
|
||||
download=False, num_samples=2000)
|
||||
|
||||
# when using multi-node we need to add the datasampler
|
||||
batch_size = self.hparams.batch_size
|
||||
|
|
|
@ -2,7 +2,6 @@
|
|||
-r ../requirements-extra.txt
|
||||
|
||||
# extended list of dependencies for development and for running lint and tests
|
||||
torchvision>=0.4.0, < 0.5 # the 0.5. has some issues with torch JIT
|
||||
tox
|
||||
coverage
|
||||
codecov
|
||||
|
@ -11,5 +10,4 @@ pytest-cov
|
|||
pytest-flake8
|
||||
flake8
|
||||
check-manifest
|
||||
twine==1.13.0
|
||||
pillow<7.0.0
|
||||
twine==1.13.0
|
Loading…
Reference in New Issue