255 lines
8.8 KiB
Python
255 lines
8.8 KiB
Python
# Copyright The PyTorch Lightning team.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import logging
|
|
import os
|
|
import platform
|
|
import random
|
|
import time
|
|
import urllib
|
|
from typing import Optional, Tuple
|
|
from urllib.error import HTTPError
|
|
from warnings import warn
|
|
|
|
import torch
|
|
from torch.utils.data import DataLoader, Dataset, random_split
|
|
|
|
from pl_examples import _DATASETS_PATH
|
|
from pytorch_lightning import LightningDataModule
|
|
from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE
|
|
|
|
if _TORCHVISION_AVAILABLE:
|
|
from torchvision import transforms as transform_lib
|
|
|
|
|
|
class _MNIST(Dataset):
|
|
"""Carbon copy of ``tests.helpers.datasets.MNIST``.
|
|
|
|
We cannot import the tests as they are not distributed with the package.
|
|
See https://github.com/PyTorchLightning/pytorch-lightning/pull/7614#discussion_r671183652 for more context.
|
|
"""
|
|
|
|
RESOURCES = (
|
|
"https://pl-public-data.s3.amazonaws.com/MNIST/processed/training.pt",
|
|
"https://pl-public-data.s3.amazonaws.com/MNIST/processed/test.pt",
|
|
)
|
|
|
|
TRAIN_FILE_NAME = "training.pt"
|
|
TEST_FILE_NAME = "test.pt"
|
|
cache_folder_name = "complete"
|
|
|
|
def __init__(
|
|
self, root: str, train: bool = True, normalize: tuple = (0.1307, 0.3081), download: bool = True, **kwargs
|
|
):
|
|
super().__init__()
|
|
self.root = root
|
|
self.train = train # training set or test set
|
|
self.normalize = normalize
|
|
|
|
self.prepare_data(download)
|
|
|
|
data_file = self.TRAIN_FILE_NAME if self.train else self.TEST_FILE_NAME
|
|
self.data, self.targets = self._try_load(os.path.join(self.cached_folder_path, data_file))
|
|
|
|
def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
|
|
img = self.data[idx].float().unsqueeze(0)
|
|
target = int(self.targets[idx])
|
|
|
|
if self.normalize is not None and len(self.normalize) == 2:
|
|
img = self.normalize_tensor(img, *self.normalize)
|
|
|
|
return img, target
|
|
|
|
def __len__(self) -> int:
|
|
return len(self.data)
|
|
|
|
@property
|
|
def cached_folder_path(self) -> str:
|
|
return os.path.join(self.root, "MNIST", self.cache_folder_name)
|
|
|
|
def _check_exists(self, data_folder: str) -> bool:
|
|
existing = True
|
|
for fname in (self.TRAIN_FILE_NAME, self.TEST_FILE_NAME):
|
|
existing = existing and os.path.isfile(os.path.join(data_folder, fname))
|
|
return existing
|
|
|
|
def prepare_data(self, download: bool = True):
|
|
if download and not self._check_exists(self.cached_folder_path):
|
|
self._download(self.cached_folder_path)
|
|
if not self._check_exists(self.cached_folder_path):
|
|
raise RuntimeError("Dataset not found.")
|
|
|
|
def _download(self, data_folder: str) -> None:
|
|
os.makedirs(data_folder, exist_ok=True)
|
|
for url in self.RESOURCES:
|
|
logging.info(f"Downloading {url}")
|
|
fpath = os.path.join(data_folder, os.path.basename(url))
|
|
urllib.request.urlretrieve(url, fpath)
|
|
|
|
@staticmethod
|
|
def _try_load(path_data, trials: int = 30, delta: float = 1.0):
|
|
"""Resolving loading from the same time from multiple concurrent processes."""
|
|
res, exception = None, None
|
|
assert trials, "at least some trial has to be set"
|
|
assert os.path.isfile(path_data), f"missing file: {path_data}"
|
|
for _ in range(trials):
|
|
try:
|
|
res = torch.load(path_data)
|
|
# todo: specify the possible exception
|
|
except Exception as e:
|
|
exception = e
|
|
time.sleep(delta * random.random())
|
|
else:
|
|
break
|
|
if exception is not None:
|
|
# raise the caught exception
|
|
raise exception
|
|
return res
|
|
|
|
@staticmethod
|
|
def normalize_tensor(tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0) -> torch.Tensor:
|
|
mean = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device)
|
|
std = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device)
|
|
return tensor.sub(mean).div(std)
|
|
|
|
|
|
def MNIST(*args, **kwargs):
    """Create an MNIST dataset, preferring torchvision's and falling back to the hosted ``_MNIST`` copy.

    Unless the environment variable ``PL_USE_MOCKED_MNIST`` is set to a non-empty value,
    ``torchvision.datasets.MNIST`` is imported and instantiated once against ``_DATASETS_PATH`` with
    ``download=True`` as an availability probe. When torchvision cannot be imported, or the probe
    fails with an ``HTTPError``, ``_MNIST`` is used instead.

    Args:
        *args: positional arguments forwarded to the selected dataset class.
        **kwargs: keyword arguments forwarded to the selected dataset class.

    Returns:
        An instance of ``torchvision.datasets.MNIST`` or of ``_MNIST``.
    """
    torchvision_mnist_available = not bool(os.getenv("PL_USE_MOCKED_MNIST", False))
    dataset_cls = _MNIST
    if torchvision_mnist_available:
        try:
            # imported under a different name so the enclosing function's name is not shadowed
            from torchvision.datasets import MNIST as torchvision_mnist

            torchvision_mnist(_DATASETS_PATH, download=True)
        except ImportError:
            # torchvision is optional for this module; use the hosted copy instead of crashing
            torchvision_mnist_available = False
        except HTTPError as e:
            print(f"Error {e} downloading `torchvision.datasets.MNIST`")
            torchvision_mnist_available = False
        else:
            dataset_cls = torchvision_mnist
    if not torchvision_mnist_available:
        print("`torchvision.datasets.MNIST` not available. Using our hosted version")
    return dataset_cls(*args, **kwargs)
|
|
|
|
|
|
class MNISTDataModule(LightningDataModule):
    """Standard MNIST, train, val, test splits and transforms.

    >>> MNISTDataModule()  # doctest: +ELLIPSIS
    <...mnist_datamodule.MNISTDataModule object at ...>
    """

    name = "mnist"

    def __init__(
        self,
        data_dir: str = _DATASETS_PATH,
        val_split: int = 5000,
        num_workers: int = 16,
        normalize: bool = False,
        seed: int = 42,
        batch_size: int = 32,
        *args,
        **kwargs,
    ):
        """
        Args:
            data_dir: where to save/load the data
            val_split: how many of the training images to use for the validation split
            num_workers: how many workers to use for loading data
            normalize: If true applies image normalize
            seed: seed of the random generator used for the train/validation split.
            batch_size: desired batch size.
        """
        super().__init__(*args, **kwargs)
        if num_workers and platform.system() == "Windows":
            # see: https://stackoverflow.com/a/59680818
            warn(
                f"You have requested num_workers={num_workers} on Windows,"
                " but currently recommended is 0, so we set it for you"
            )
            num_workers = 0

        self.data_dir = data_dir
        self.val_split = val_split
        self.num_workers = num_workers
        self.normalize = normalize
        self.seed = seed
        self.batch_size = batch_size
        # placeholders until ``setup`` creates the actual splits
        self.dataset_train = ...
        self.dataset_val = ...

    @property
    def num_classes(self):
        """Number of target classes (always 10)."""
        return 10

    def prepare_data(self):
        """Saves MNIST files to `data_dir`"""
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage: Optional[str] = None):
        """Split the train and valid dataset.

        The last ``val_split`` samples of a random permutation become the validation set. The
        permutation is drawn from a generator seeded with ``seed`` so the split is reproducible.

        Args:
            stage: accepted for API compatibility; not used here.
        """
        dataset = MNIST(self.data_dir, train=True, download=False, **self._transform_kwargs())
        train_length = len(dataset)
        self.dataset_train, self.dataset_val = random_split(
            dataset,
            [train_length - self.val_split, self.val_split],
            # ``seed`` was previously stored but never applied, which made the split non-deterministic
            generator=torch.Generator().manual_seed(self.seed),
        )

    def train_dataloader(self):
        """MNIST train set removes a subset to use for validation."""
        return self._data_loader(self.dataset_train, shuffle=True)

    def val_dataloader(self):
        """MNIST val set uses a subset of the training set for validation."""
        return self._data_loader(self.dataset_val)

    def test_dataloader(self):
        """MNIST test set uses the test split."""
        dataset = MNIST(self.data_dir, train=False, download=False, **self._transform_kwargs())
        return self._data_loader(dataset)

    @property
    def default_transforms(self):
        """Transform applied to every image, or ``None`` when torchvision is not available.

        ``ToTensor`` alone, or ``ToTensor`` followed by ``Normalize(mean=(0.5,), std=(0.5,))`` when
        ``normalize`` is true.
        """
        if not _TORCHVISION_AVAILABLE:
            return None
        if self.normalize:
            mnist_transforms = transform_lib.Compose(
                [transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))]
            )
        else:
            mnist_transforms = transform_lib.ToTensor()

        return mnist_transforms

    def _transform_kwargs(self) -> dict:
        """Return the ``transform=...`` keyword for the dataset, or nothing when no transform applies."""
        transforms = self.default_transforms
        return dict(transform=transforms) if transforms else {}

    def _data_loader(self, dataset, shuffle: bool = False) -> DataLoader:
        """Build a ``DataLoader`` over ``dataset`` with the settings shared by all three splits."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            drop_last=True,
            pin_memory=True,
        )
|