lightning/tests/base/datasets.py

Add MNIST dataset & drop torchvision dep. from tests (#986) * added custom mnist without torchvision dep * move files so it does not conflict with mnist gitignore * mock torchvision for tests * fix line too long * fix line too long * fix "module level import not at top of file" warning * move mock imports to __init__.py * simplify MNIST a lot and download directly the .pt files * further simplify and clean up mnist * revert import overrides * make as before * drop PIL requirement * move mnist.py to datasets subfolder * use logging instead of print * choose same name as in torchvision * remove torchvision and pillow also from yml file * refactor if train Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * capitalized class attr * moved mnist to models * re-added datsets ignore * better name for file variable * Update mnist.py * move dataset classes to datasets.py * new line * update * update * fix automerge * move to base folder * adapt testingmnist to new mnist base class * remove temporal fix * fix datatype * remove old testingmnist * readable * fix import * fix whitespace * docstring Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * Update tests/base/datasets.py Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * changelog * added types * Update CHANGELOG.md Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * exist->isfile Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * index -> idx * temporary fix for trains error * better changelog message Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
2020-03-30 22:25:37 +00:00
import logging
import os
import urllib.request
from typing import Tuple, Optional, Sequence
import torch
from torch import Tensor
from torch.utils.data import Dataset
from tests import PACKAGE_ROOT
#: local path to test datasets
PATH_DATASETS = os.path.join(PACKAGE_ROOT, 'Datasets')


class MNIST(Dataset):
    """
    Customized `MNIST <http://yann.lecun.com/exdb/mnist/>`_ dataset for testing PyTorch Lightning
    without the torchvision dependency.

    Part of the code was copied from
    https://github.com/pytorch/vision/blob/build/v0.5.0/torchvision/datasets/mnist.py

    Args:
        root: Root directory of the dataset, where ``MNIST/complete/training.pt``
            and ``MNIST/complete/test.pt`` are cached.
        train: If ``True``, creates the dataset from ``training.pt``,
            otherwise from ``test.pt``.
        normalize: Tuple ``(mean, std)`` applied to each image;
            ``None`` disables normalization.
        download: If ``True``, downloads the dataset from the internet and
            puts it in the root directory. If the dataset is already downloaded,
            it is not downloaded again.

    Examples:
        >>> dataset = MNIST(download=True)
        >>> len(dataset)
        60000
        >>> torch.bincount(dataset.targets)
        tensor([5923, 6742, 5958, 6131, 5842, 5421, 5918, 6265, 5851, 5949])
    """

    RESOURCES = (
        "https://pl-public-data.s3.amazonaws.com/MNIST/processed/training.pt",
        "https://pl-public-data.s3.amazonaws.com/MNIST/processed/test.pt",
    )

    TRAIN_FILE_NAME = 'training.pt'
    TEST_FILE_NAME = 'test.pt'
    cache_folder_name = 'complete'

    def __init__(self, root: str = PATH_DATASETS, train: bool = True,
                 normalize: tuple = (0.5, 1.0), download: bool = True):
        super().__init__()
        self.root = root
        self.train = train  # training set or test set
        self.normalize = normalize

        self.prepare_data(download)
        if not self._check_exists(self.cached_folder_path):
            raise RuntimeError('Dataset not found.')

        data_file = self.TRAIN_FILE_NAME if self.train else self.TEST_FILE_NAME
        self.data, self.targets = torch.load(os.path.join(self.cached_folder_path, data_file))

    def __getitem__(self, idx: int) -> Tuple[Tensor, int]:
        img = self.data[idx].float().unsqueeze(0)
        target = int(self.targets[idx])

        if self.normalize is not None:
            img = normalize_tensor(img, mean=self.normalize[0], std=self.normalize[1])

        return img, target

    def __len__(self) -> int:
        return len(self.data)

    @property
    def cached_folder_path(self) -> str:
        return os.path.join(self.root, 'MNIST', self.cache_folder_name)

    def _check_exists(self, data_folder: str) -> bool:
        existing = True
        for fname in (self.TRAIN_FILE_NAME, self.TEST_FILE_NAME):
            existing = existing and os.path.isfile(os.path.join(data_folder, fname))
        return existing

    def prepare_data(self, download: bool):
        if download:
            self._download(self.cached_folder_path)

    def _download(self, data_folder: str) -> None:
        """Download the MNIST data if it doesn't exist in cached_folder_path already."""
        if self._check_exists(data_folder):
            return

        os.makedirs(data_folder, exist_ok=True)

        for url in self.RESOURCES:
            logging.info(f'Downloading {url}')
            fpath = os.path.join(data_folder, os.path.basename(url))
            urllib.request.urlretrieve(url, fpath)
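
The download logic above only touches the network when at least one expected file is missing from the cache folder. A minimal offline sketch of that check-then-fetch pattern follows; `files_present`, `ensure_files`, and `fake_fetch` are hypothetical names introduced for illustration, and `fake_fetch` stands in for `urllib.request.urlretrieve` so the example runs without a connection.

```python
import os
import tempfile
from typing import Callable, Sequence


def files_present(folder: str, names: Sequence[str]) -> bool:
    """Return True only if every expected file exists in ``folder``."""
    return all(os.path.isfile(os.path.join(folder, name)) for name in names)


def ensure_files(folder: str, names: Sequence[str],
                 fetch: Callable[[str, str], None]) -> int:
    """Fetch all expected files unless the cache is complete; return how many were fetched."""
    if files_present(folder, names):
        return 0
    os.makedirs(folder, exist_ok=True)
    for name in names:
        fetch(name, os.path.join(folder, name))
    return len(names)


def fake_fetch(name: str, destination: str) -> None:
    # Hypothetical stand-in for a real download: writes a small placeholder file.
    with open(destination, "wb") as handle:
        handle.write(name.encode())


with tempfile.TemporaryDirectory() as root:
    cache = os.path.join(root, "MNIST", "complete")
    names = ("training.pt", "test.pt")
    first = ensure_files(cache, names, fake_fetch)   # cache is empty, so both files are fetched
    second = ensure_files(cache, names, fake_fetch)  # cache is complete, so nothing is fetched
```

As in `_download`, a partially filled cache causes every file to be fetched again, which keeps the check simple at the cost of some redundant transfer.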


def normalize_tensor(tensor: Tensor, mean: float = 0.0, std: float = 1.0) -> Tensor:
    """Return a normalized copy of ``tensor``, computed as ``(tensor - mean) / std``."""
    tensor = tensor.clone()
    mean = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device)
    std = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device)
    tensor.sub_(mean).div_(std)
    return tensor
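
The arithmetic in `normalize_tensor` is an element-wise `(value - mean) / std` applied to a copy of the input. The dependency-free sketch below mirrors it on a plain list; `normalize_values` is a hypothetical helper, and the sample pixel values assume raw intensities in the range 0 to 255, which this file does not state explicitly.

```python
from typing import List


def normalize_values(values: List[float], mean: float = 0.0, std: float = 1.0) -> List[float]:
    """Return a new list with (value - mean) / std applied to every element."""
    return [(value - mean) / std for value in values]


pixels = [0.0, 128.0, 255.0]  # assumed raw intensities, for illustration only
shifted = normalize_values(pixels, mean=0.5, std=1.0)
# With the default (0.5, 1.0) pair the values are only shifted by 0.5, not rescaled.
```

Like the tensor version, the sketch leaves its input untouched and returns a new object.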


class TrialMNIST(MNIST):
    """Constrained image dataset: a small subset of MNIST with a fixed number of samples per digit.

    Args:
        root: Root directory of the dataset, where ``MNIST/<cache folder>/training.pt``
            and ``MNIST/<cache folder>/test.pt`` are cached; the cache folder
            is named after ``digits`` and ``num_samples``.
        train: If ``True``, creates the dataset from ``training.pt``,
            otherwise from ``test.pt``.
        normalize: Tuple ``(mean, std)`` applied to each image;
            ``None`` disables normalization.
        download: If ``True``, downloads the dataset from the internet and
            puts it in the root directory. If the dataset is already downloaded,
            it is not downloaded again.
        num_samples: Number of examples per selected class/digit.
        digits: Selected MNIST digits/classes; ``None`` or an empty
            sequence selects all ten digits.

    Examples:
        >>> dataset = TrialMNIST(download=True)
        >>> len(dataset)
        300
        >>> sorted(set([d.item() for d in dataset.targets]))
        [0, 1, 2]
        >>> torch.bincount(dataset.targets)
        tensor([100, 100, 100])
    """

    def __init__(self, root: str = PATH_DATASETS, train: bool = True,
                 normalize: tuple = (0.5, 1.0), download: bool = False,
                 num_samples: int = 100, digits: Optional[Sequence] = (0, 1, 2)):

        # number of examples per class
        self.num_samples = num_samples
        # take just a subset of MNIST dataset
        self.digits = digits if digits else list(range(10))

        self.cache_folder_name = 'digits-' + '-'.join(str(d) for d in sorted(self.digits)) \
            + f'_nb-{self.num_samples}'

        super().__init__(
            root,
            train=train,
            normalize=normalize,
            download=download
        )

    @staticmethod
    def _prepare_subset(full_data: torch.Tensor, full_targets: torch.Tensor,
                        num_samples: int, digits: Sequence):
        classes = {d: 0 for d in digits}
        indexes = []
        for idx, target in enumerate(full_targets):
            label = target.item()
            if classes.get(label, float('inf')) >= num_samples:
                continue
            indexes.append(idx)
            classes[label] += 1
            if all(classes[k] >= num_samples for k in classes):
                break
        data = full_data[indexes]
        targets = full_targets[indexes]
        return data, targets

    def prepare_data(self, download: bool) -> None:
        if self._check_exists(self.cached_folder_path):
            return
        if download:
            self._download(super().cached_folder_path)

        for fname in (self.TRAIN_FILE_NAME, self.TEST_FILE_NAME):
            path_fname = os.path.join(super().cached_folder_path, fname)
            assert os.path.isfile(path_fname), 'Missing cached file: %s' % path_fname
            data, targets = torch.load(path_fname)
            data, targets = self._prepare_subset(data, targets, self.num_samples, self.digits)
            torch.save((data, targets), os.path.join(self.cached_folder_path, fname))
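
The subset selection in `TrialMNIST._prepare_subset` keeps the first `num_samples` occurrences of each requested digit and stops scanning once every class is full. The sketch below reproduces that logic on a plain list of labels so it runs without torch; `select_indexes` and the sample `labels` are hypothetical and serve only as illustration. The last line rebuilds the cache folder name the same way `__init__` does.

```python
from typing import Dict, List, Sequence


def select_indexes(labels: Sequence[int], num_samples: int, digits: Sequence[int]) -> List[int]:
    """Keep the first ``num_samples`` positions of each label listed in ``digits``."""
    counts: Dict[int, int] = {d: 0 for d in digits}
    picked: List[int] = []
    for idx, label in enumerate(labels):
        # Labels outside ``digits`` default to infinity and are therefore always skipped.
        if counts.get(label, float("inf")) >= num_samples:
            continue
        picked.append(idx)
        counts[label] += 1
        if all(count >= num_samples for count in counts.values()):
            break  # every requested class is full, so stop scanning
    return picked


labels = [5, 0, 1, 0, 2, 1, 0, 2, 2, 1]
picked = select_indexes(labels, num_samples=2, digits=(0, 1, 2))
# picked == [1, 2, 3, 4, 5, 7]: two samples each of digits 0, 1 and 2

# Cache folder name built the same way as in TrialMNIST.__init__.
folder = "digits-" + "-".join(str(d) for d in sorted((0, 1, 2))) + "_nb-2"
```

Because selection follows the stored order, the subset is deterministic for a given file, which suits a test fixture.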