108 lines
3.6 KiB
Python
108 lines
3.6 KiB
Python
import logging
|
|
import os
|
|
import urllib.request
|
|
from typing import Tuple
|
|
|
|
import torch
|
|
from torch import Tensor
|
|
from torch.utils.data import Dataset
|
|
|
|
|
|
class MNIST(Dataset):
    """
    Customized `MNIST <http://yann.lecun.com/exdb/mnist/>`_ dataset for testing PyTorch Lightning
    without the torchvision dependency.

    Part of the code was copied from
    https://github.com/pytorch/vision/blob/build/v0.5.0/torchvision/datasets/mnist.py

    Args:
        root: Root directory of dataset where ``MNIST/processed/training.pt``
            and ``MNIST/processed/test.pt`` exist.
        train: If ``True``, creates dataset from ``training.pt``,
            otherwise from ``test.pt``.
        normalize: ``(mean, std)`` pair applied to every image in ``__getitem__``.
            Pass ``None`` to get the images back un-normalized.
        download: If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.

    Raises:
        RuntimeError: If the processed files are still missing after the optional download step.
    """

    RESOURCES = (
        "https://pl-public-data.s3.amazonaws.com/MNIST/processed/training.pt",
        "https://pl-public-data.s3.amazonaws.com/MNIST/processed/test.pt",
    )

    TRAIN_FILE_NAME = 'training.pt'
    TEST_FILE_NAME = 'test.pt'

    def __init__(self, root: str, train: bool = True, normalize: tuple = (0.5, 1.0), download: bool = False):
        super().__init__()
        self.root = root
        self.train = train  # training set or test set
        self.normalize = normalize

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found.')

        data_file = self.TRAIN_FILE_NAME if self.train else self.TEST_FILE_NAME
        # NOTE(review): torch.load unpickles the file, so `root` must only point at trusted data.
        self.data, self.targets = torch.load(os.path.join(self.processed_folder, data_file))

    def __getitem__(self, idx: int) -> Tuple[Tensor, int]:
        """Return the ``(image, label)`` pair stored at position ``idx``."""
        # Cast to float and prepend an axis of size 1 (presumably the channel axis).
        img = self.data[idx].float().unsqueeze(0)
        target = int(self.targets[idx])

        if self.normalize is not None:
            img = normalize_tensor(img, mean=self.normalize[0], std=self.normalize[1])

        return img, target

    def __len__(self) -> int:
        """Return the number of samples held in ``self.data``."""
        return len(self.data)

    @property
    def processed_folder(self) -> str:
        """Directory below ``root`` that holds the ``.pt`` files."""
        return os.path.join(self.root, 'MNIST', 'processed')

    def _check_exists(self) -> bool:
        """Return ``True`` only when both the training and the test file are present."""
        train_file = os.path.join(self.processed_folder, self.TRAIN_FILE_NAME)
        test_file = os.path.join(self.processed_folder, self.TEST_FILE_NAME)
        return os.path.isfile(train_file) and os.path.isfile(test_file)

    def download(self) -> None:
        """Download the MNIST data if it doesn't exist in processed_folder already.

        Each resource is first written to a temporary file next to its destination and
        only moved into place once the transfer has completed. An interrupted download
        therefore never leaves a truncated ``.pt`` file that ``_check_exists`` would
        later mistake for a valid one.
        """
        if self._check_exists():
            return

        os.makedirs(self.processed_folder, exist_ok=True)

        for url in self.RESOURCES:
            logging.info('Downloading %s', url)
            fpath = os.path.join(self.processed_folder, os.path.basename(url))
            # The pid keeps concurrent processes from writing to the same partial file.
            tmp_path = f'{fpath}.{os.getpid()}.part'
            try:
                urllib.request.urlretrieve(url, tmp_path)
                # Same directory, hence same filesystem: the rename is atomic.
                os.replace(tmp_path, fpath)
            finally:
                # Still present only when the transfer or the rename failed.
                if os.path.exists(tmp_path):
                    os.remove(tmp_path)
|
|
|
|
|
|
def normalize_tensor(tensor: Tensor, mean: float = 0.0, std: float = 1.0) -> Tensor:
    """Return a standardized copy of ``tensor``, i.e. ``(tensor - mean) / std``.

    The arithmetic runs in place on a clone, so the caller's tensor is left untouched
    and the result keeps the input's dtype and device.
    """
    # Build both statistics with the input's dtype and device.
    like = {'dtype': tensor.dtype, 'device': tensor.device}
    shift = torch.as_tensor(mean, **like)
    scale = torch.as_tensor(std, **like)

    result = tensor.clone()
    result.sub_(shift)
    result.div_(scale)
    return result
|
|
|
|
|
|
class TestingMNIST(MNIST):
    """``MNIST`` variant that keeps only the first ``num_samples`` entries after loading."""

    def __init__(self, root, train=True, normalize=(0.5, 1.0), download=False, num_samples=8000):
        super().__init__(root, train=train, normalize=normalize, download=download)

        # take just a subset of MNIST dataset
        head = slice(num_samples)
        self.data, self.targets = self.data[head], self.targets[head]
|