# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import random
import time
import urllib.request
from typing import Optional, Sequence, Tuple

import torch
from torch import Tensor
from torch.utils.data import Dataset


class MNIST(Dataset):
    """
    Customized `MNIST <http://yann.lecun.com/exdb/mnist/>`_ dataset for testing Pytorch Lightning
    without the torchvision dependency.

    Part of the code was copied from
    https://github.com/pytorch/vision/blob/build/v0.5.0/torchvision/datasets/mnist.py

    Args:
        root: Root directory of the dataset. The files ``training.pt`` and ``test.pt`` are stored
            (and looked up) in ``<root>/MNIST/<cache_folder_name>/``, i.e. ``<root>/MNIST/complete/``
            for this class.
        train: If ``True``, creates dataset from ``training.pt``,
            otherwise from ``test.pt``.
        normalize: ``(mean, std)`` pair used to normalize every image in ``__getitem__``.
            Pass ``None`` to return the raw (float-cast) pixel values.
        download: If ``True``, downloads the dataset from the internet and puts it in the root
            directory. If the dataset is already downloaded, it is not downloaded again.
            If ``False`` and the files are missing, ``RuntimeError`` is raised.
        kwargs: Ignored; accepted so subclasses can forward extra keyword arguments.

    Examples:
        >>> dataset = MNIST(".", download=True)
        >>> len(dataset)
        60000
        >>> torch.bincount(dataset.targets)
        tensor([5923, 6742, 5958, 6131, 5842, 5421, 5918, 6265, 5851, 5949])
    """

    RESOURCES = (
        "https://pl-public-data.s3.amazonaws.com/MNIST/processed/training.pt",
        "https://pl-public-data.s3.amazonaws.com/MNIST/processed/test.pt",
    )

    TRAIN_FILE_NAME = 'training.pt'
    TEST_FILE_NAME = 'test.pt'
    cache_folder_name = 'complete'

    def __init__(
        self,
        root: str,
        train: bool = True,
        normalize: Optional[Tuple[float, float]] = (0.1307, 0.3081),
        download: bool = True,
        **kwargs,
    ):
        super().__init__()
        self.root = root
        self.train = train  # training set or test set
        self.normalize = normalize

        self.prepare_data(download)

        data_file = self.TRAIN_FILE_NAME if self.train else self.TEST_FILE_NAME
        self.data, self.targets = self._try_load(os.path.join(self.cached_folder_path, data_file))

    def __getitem__(self, idx: int) -> Tuple[Tensor, int]:
        """Return ``(image, label)``: the image gains a leading channel dim and is optionally normalized."""
        img = self.data[idx].float().unsqueeze(0)
        target = int(self.targets[idx])

        if self.normalize is not None and len(self.normalize) == 2:
            img = self.normalize_tensor(img, *self.normalize)

        return img, target

    def __len__(self) -> int:
        return len(self.data)

    @property
    def cached_folder_path(self) -> str:
        """Folder holding this dataset's ``training.pt`` / ``test.pt`` files."""
        return os.path.join(self.root, 'MNIST', self.cache_folder_name)

    def _check_exists(self, data_folder: str) -> bool:
        """Return ``True`` only if both the train and the test file exist in ``data_folder``."""
        return all(
            os.path.isfile(os.path.join(data_folder, fname)) for fname in (self.TRAIN_FILE_NAME, self.TEST_FILE_NAME)
        )

    def prepare_data(self, download: bool = True) -> None:
        """Download the files if requested and missing; raise ``RuntimeError`` if they are still missing."""
        if download and not self._check_exists(self.cached_folder_path):
            self._download(self.cached_folder_path)
        if not self._check_exists(self.cached_folder_path):
            raise RuntimeError('Dataset not found.')

    def _download(self, data_folder: str) -> None:
        """Fetch every URL in ``RESOURCES`` into ``data_folder``, keeping the URL's file name."""
        os.makedirs(data_folder, exist_ok=True)
        for url in self.RESOURCES:
            logging.info(f'Downloading {url}')
            fpath = os.path.join(data_folder, os.path.basename(url))
            urllib.request.urlretrieve(url, fpath)

    @staticmethod
    def _try_load(path_data, trials: int = 30, delta: float = 1.):
        """Load ``path_data`` with ``torch.load``, retrying when a read fails.

        Resolves loading the same file at the same time from multiple concurrent processes
        (e.g. while another process is still writing it). After each failed attempt the call sleeps
        a random time in ``[0, delta)`` seconds before trying again.

        Args:
            path_data: Path of the file to load.
            trials: Maximum number of load attempts.
            delta: Upper bound (in seconds) of the random back-off between attempts.

        Returns:
            The loaded object as soon as one attempt succeeds.

        Raises:
            AssertionError: If ``trials`` is falsy or ``path_data`` is not an existing file.
            Exception: The error of the last attempt, if every attempt failed.
        """
        assert trials, "at least some trial has to be set"
        assert os.path.isfile(path_data), f'missing file: {path_data}'
        exception = None
        for _ in range(trials):
            try:
                # Return immediately on success so that an error from an earlier,
                # transiently failing attempt is not re-raised afterwards.
                return torch.load(path_data)
            # todo: specify the possible exception
            except Exception as e:
                exception = e
                time.sleep(delta * random.random())
        if exception is not None:
            # every attempt failed: raise the last caught exception
            raise exception
        # only reachable when ``range(trials)`` is empty (e.g. a negative ``trials``)
        return None

    @staticmethod
    def normalize_tensor(tensor: Tensor, mean: float = 0.0, std: float = 1.0) -> Tensor:
        """Return ``(tensor - mean) / std`` computed in ``tensor``'s dtype and on its device."""
        mean = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device)
        std = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device)
        return tensor.sub(mean).div(std)


class TrialMNIST(MNIST):
    """Constrained MNIST dataset

    Args:
        num_samples: number of examples per selected class/digit
        digits: list selected MNIST digits/classes (an empty/``None`` value selects all ten digits)
        kwargs: Same as MNIST; ``normalize`` defaults to ``(0.5, 1.0)`` for this subclass.

    Examples:
        >>> dataset = TrialMNIST(".", download=True)
        >>> len(dataset)
        300
        >>> sorted(set([d.item() for d in dataset.targets]))
        [0, 1, 2]
        >>> torch.bincount(dataset.targets)
        tensor([100, 100, 100])
    """

    def __init__(self, root: str, num_samples: int = 100, digits: Optional[Sequence] = (0, 1, 2), **kwargs):
        # number of examples per class
        self.num_samples = num_samples
        # take just a subset of MNIST dataset
        self.digits = sorted(digits) if digits else list(range(10))

        # a dedicated cache folder per (digits, num_samples) combination
        self.cache_folder_name = f"digits-{'-'.join(str(d) for d in self.digits)}_nb-{self.num_samples}"

        # Keep the historical default but let callers override it; passing ``normalize`` used to
        # fail with "got multiple values for keyword argument".
        kwargs.setdefault('normalize', (0.5, 1.0))
        super().__init__(root, **kwargs)

    @staticmethod
    def _prepare_subset(full_data: torch.Tensor, full_targets: torch.Tensor, num_samples: int, digits: Sequence):
        """Keep the first ``num_samples`` examples of each digit in ``digits``, preserving their order.

        Labels outside ``digits`` are skipped; the scan stops early once every digit is full.
        """
        classes = {d: 0 for d in digits}
        indexes = []
        for idx, target in enumerate(full_targets):
            label = target.item()
            # unknown labels map to +inf and are therefore always skipped
            if classes.get(label, float('inf')) >= num_samples:
                continue
            indexes.append(idx)
            classes[label] += 1
            if all(classes[k] >= num_samples for k in classes):
                break
        data = full_data[indexes]
        targets = full_targets[indexes]
        return data, targets

    def _download(self, data_folder: str) -> None:
        """Download the full MNIST files into ``data_folder`` and overwrite them with the subset."""
        super()._download(data_folder)
        for fname in (self.TRAIN_FILE_NAME, self.TEST_FILE_NAME):
            path_fname = os.path.join(data_folder, fname)
            assert os.path.isfile(path_fname), f'Missing cached file: {path_fname}'
            data, targets = self._try_load(path_fname)
            data, targets = self._prepare_subset(data, targets, self.num_samples, self.digits)
            torch.save((data, targets), path_fname)


class AverageDataset(Dataset):
    """Random regression dataset of ``dataset_len`` sequences.

    Each input has shape ``(sequence_len, 10)``. The matching target has shape ``(sequence_len, 5)``
    and equals the first five features plus the last five features rolled by one position along
    the feature axis.
    """

    def __init__(self, dataset_len=300, sequence_len=100):
        self.dataset_len = dataset_len
        self.sequence_len = sequence_len
        self.input_seq = torch.randn(dataset_len, sequence_len, 10)
        top, bottom = self.input_seq.chunk(2, -1)
        self.output_seq = top + bottom.roll(shifts=1, dims=-1)

    def __len__(self):
        return self.dataset_len

    def __getitem__(self, item):
        return self.input_seq[item], self.output_seq[item]


class SklearnDataset(Dataset):
    """Wrap indexable ``x`` / ``y`` (e.g. arrays) and convert each sample to tensors of the given dtypes."""

    def __init__(self, x, y, x_type, y_type):
        self.x = x
        self.y = y
        self._x_type = x_type
        self._y_type = y_type

    def __getitem__(self, idx):
        return torch.tensor(self.x[idx], dtype=self._x_type), torch.tensor(self.y[idx], dtype=self._y_type)

    def __len__(self):
        return len(self.y)