Delete tests.helpers.TrialMNISTDataModule (#5999)

* Remove TrialMNISTDataModule

* Allow using TrialMNIST in the MNISTDataModule

* Update tests/helpers/datasets.py
Authored by Carlos Mocholí on 2021-02-18 04:35:38 +01:00; committed by GitHub
parent d2cd7cb0f9
commit bfcfac4614
4 changed files with 62 additions and 143 deletions

tests/helpers/datamodules.py

@@ -11,68 +11,32 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Dict, Optional
+from typing import Optional

 import torch
-from sklearn.datasets import make_classification, make_regression
-from sklearn.model_selection import train_test_split
-from torch.utils.data import DataLoader, random_split
-from torch.utils.data.distributed import DistributedSampler
+from torch.utils.data import DataLoader

 from pytorch_lightning.core.datamodule import LightningDataModule
+from pytorch_lightning.utilities import _module_available
 from tests.helpers.datasets import MNIST, SklearnDataset, TrialMNIST


-class TrialMNISTDataModule(LightningDataModule):
-
-    def __init__(self, data_dir: str = "./"):
-        super().__init__()
-        self.data_dir = data_dir
-        self.non_picklable = None
-        self.checkpoint_state: Optional[str] = None
-
-    def prepare_data(self):
-        TrialMNIST(self.data_dir, train=True, download=True)
-        TrialMNIST(self.data_dir, train=False, download=True)
-
-    def setup(self, stage: Optional[str] = None):
-        if stage == "fit" or stage is None:
-            mnist_full = TrialMNIST(root=self.data_dir, train=True, num_samples=64, download=True)
-            self.mnist_train, self.mnist_val = random_split(mnist_full, [128, 64])
-            self.dims = self.mnist_train[0][0].shape
-
-        if stage == "test" or stage is None:
-            self.mnist_test = TrialMNIST(root=self.data_dir, train=False, num_samples=64, download=True)
-            self.dims = getattr(self, "dims", self.mnist_test[0][0].shape)
-
-        self.non_picklable = lambda x: x**2
-
-    def train_dataloader(self):
-        return DataLoader(self.mnist_train, batch_size=32)
-
-    def val_dataloader(self):
-        return DataLoader(self.mnist_val, batch_size=32)
-
-    def test_dataloader(self):
-        return DataLoader(self.mnist_test, batch_size=32)
-
-    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
-        checkpoint[self.__class__.__name__] = self.__class__.__name__
-
-    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
-        self.checkpoint_state = checkpoint.get(self.__class__.__name__)
+_SKLEARN_AVAILABLE = _module_available("sklearn")
+if _SKLEARN_AVAILABLE:
+    from sklearn.datasets import make_classification, make_regression
+    from sklearn.model_selection import train_test_split


 class MNISTDataModule(LightningDataModule):

-    def __init__(self, data_dir: str = "./", batch_size: int = 32, dist_sampler: bool = False) -> None:
+    def __init__(self, data_dir: str = "./", batch_size: int = 32, use_trials: bool = False) -> None:
         super().__init__()
-        self.dist_sampler = dist_sampler
         self.data_dir = data_dir
         self.batch_size = batch_size
+        # TrialMNIST is a constrained MNIST dataset
+        self.dataset_cls = TrialMNIST if use_trials else MNIST

         # self.dims is returned when you call dm.size()
         # Setting default dims here because we know them.
         # Could optionally be assigned dynamically in dm.setup()
@@ -80,31 +44,18 @@ class MNISTDataModule(LightningDataModule):
     def prepare_data(self):
         # download only
-        MNIST(self.data_dir, train=True, download=True, normalize=(0.1307, 0.3081))
-        MNIST(self.data_dir, train=False, download=True, normalize=(0.1307, 0.3081))
+        self.dataset_cls(self.data_dir, train=True, download=True)
+        self.dataset_cls(self.data_dir, train=False, download=True)

     def setup(self, stage: Optional[str] = None):
         # Assign train/val datasets for use in dataloaders
         # TODO: need to split using random_split once updated to torch >= 1.6
         if stage == "fit" or stage is None:
-            self.mnist_train = MNIST(self.data_dir, train=True, normalize=(0.1307, 0.3081))
+            self.mnist_train = self.dataset_cls(self.data_dir, train=True)

         # Assign test dataset for use in dataloader(s)
         if stage == "test" or stage is None:
-            self.mnist_test = MNIST(self.data_dir, train=False, normalize=(0.1307, 0.3081))
+            self.mnist_test = self.dataset_cls(self.data_dir, train=False)

     def train_dataloader(self):
-        dist_sampler = None
-        if self.dist_sampler:
-            dist_sampler = DistributedSampler(self.mnist_train, shuffle=False)
-        return DataLoader(
-            self.mnist_train,
-            batch_size=self.batch_size,
-            sampler=dist_sampler,
-            shuffle=False,
-        )
+        return DataLoader(self.mnist_train, batch_size=self.batch_size, shuffle=False)

     def test_dataloader(self):
         return DataLoader(self.mnist_test, batch_size=self.batch_size, shuffle=False)
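
For illustration only (not part of the commit), a minimal sketch of how the refactored MNISTDataModule is meant to be used: the new use_trials flag swaps the full MNIST dataset for the constrained TrialMNIST subset.

    # Hypothetical usage sketch; assumes the tests.helpers package is importable.
    from tests.helpers.datamodules import MNISTDataModule

    dm = MNISTDataModule(data_dir="./", batch_size=32)          # full MNIST, as before
    dm_trial = MNISTDataModule(data_dir="./", use_trials=True)  # constrained TrialMNIST subset
    dm_trial.prepare_data()                # downloads (or reuses) the cached train/test files
    dm_trial.setup("fit")
    loader = dm_trial.train_dataloader()   # DataLoader over the trial subset

Keeping a single datamodule with a dataset_cls switch avoids maintaining two nearly identical test helpers, which is the point of this commit.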

tests/helpers/datasets.py

@@ -67,7 +67,7 @@ class MNIST(Dataset):
         self,
         root: str = PATH_DATASETS,
         train: bool = True,
-        normalize: tuple = (0.5, 1.0),
+        normalize: tuple = (0.1307, 0.3081),
         download: bool = True,
     ):
         super().__init__()
@@ -77,18 +77,15 @@ class MNIST(Dataset):
         self.prepare_data(download)

-        if not self._check_exists(self.cached_folder_path):
-            raise RuntimeError('Dataset not found.')
-
         data_file = self.TRAIN_FILE_NAME if self.train else self.TEST_FILE_NAME
-        self.data, self.targets = _try_load(os.path.join(self.cached_folder_path, data_file))
+        self.data, self.targets = self._try_load(os.path.join(self.cached_folder_path, data_file))

     def __getitem__(self, idx: int) -> Tuple[Tensor, int]:
         img = self.data[idx].float().unsqueeze(0)
         target = int(self.targets[idx])
-        if self.normalize is not None:
-            img = normalize_tensor(img, mean=self.normalize[0], std=self.normalize[1])
+        if self.normalize is not None and len(self.normalize) == 2:
+            img = self.normalize_tensor(img, *self.normalize)
         return img, target
@@ -105,67 +102,53 @@ class MNIST(Dataset):
             existing = existing and os.path.isfile(os.path.join(data_folder, fname))
         return existing

-    def prepare_data(self, download: bool):
-        if download:
+    def prepare_data(self, download: bool = True):
+        if download and not self._check_exists(self.cached_folder_path):
             self._download(self.cached_folder_path)
+        if not self._check_exists(self.cached_folder_path):
+            raise RuntimeError('Dataset not found.')

     def _download(self, data_folder: str) -> None:
-        """Download the MNIST data if it doesn't exist in cached_folder_path already."""
-        if self._check_exists(data_folder):
-            return
-        os.makedirs(data_folder)
+        os.makedirs(data_folder, exist_ok=True)
         for url in self.RESOURCES:
             logging.info(f'Downloading {url}')
             fpath = os.path.join(data_folder, os.path.basename(url))
             urllib.request.urlretrieve(url, fpath)

-
-def _try_load(path_data, trials: int = 30, delta: float = 1.):
-    """Resolving loading from the same time from multiple concurrentprocesses."""
-    res, exp = None, None
-    assert trials, "at least some trial has to be set"
-    assert os.path.isfile(path_data), 'missing file: %s' % path_data
-    for _ in range(trials):
-        try:
-            res = torch.load(path_data)
-        # todo: specify the possible exception
-        except Exception as ex:
-            exp = ex
-            time.sleep(delta * random.random())
-        else:
-            break
-    else:
-        # raise the caught exception if any
-        if exp:
-            raise exp
-    return res
-
-
-def normalize_tensor(tensor: Tensor, mean: float = 0.0, std: float = 1.0) -> Tensor:
-    tensor = tensor.clone()
-    mean = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device)
-    std = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device)
-    tensor.sub_(mean).div_(std)
-    return tensor
+    @staticmethod
+    def _try_load(path_data, trials: int = 30, delta: float = 1.):
+        """Resolving loading from the same time from multiple concurrent processes."""
+        res, exception = None, None
+        assert trials, "at least some trial has to be set"
+        assert os.path.isfile(path_data), f'missing file: {path_data}'
+        for _ in range(trials):
+            try:
+                res = torch.load(path_data)
+            # todo: specify the possible exception
+            except Exception as e:
+                exception = e
+                time.sleep(delta * random.random())
+            else:
+                break
+        if exception is not None:
+            # raise the caught exception
+            raise exception
+        return res
+
+    @staticmethod
+    def normalize_tensor(tensor: Tensor, mean: float = 0.0, std: float = 1.0) -> Tensor:
+        mean = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device)
+        std = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device)
+        return tensor.sub(mean).div(std)


 class TrialMNIST(MNIST):
-    """Constrain image dataset
+    """Constrained MNIST dataset

     Args:
-        root: Root directory of dataset where ``MNIST/processed/training.pt``
-            and ``MNIST/processed/test.pt`` exist.
-        train: If ``True``, creates dataset from ``training.pt``,
-            otherwise from ``test.pt``.
-        normalize: mean and std deviation of the MNIST dataset.
-        download: If true, downloads the dataset from the internet and
-            puts it in root directory. If dataset is already downloaded, it is not
-            downloaded again.
         num_samples: number of examples per selected class/digit
         digits: list selected MNIST digits/classes
+        kwargs: Same as MNIST

     Examples:
         >>> dataset = TrialMNIST(download=True)
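
A side note on the default change above: (0.1307, 0.3081) are the commonly quoted mean and standard deviation of the MNIST training pixels, replacing the earlier placeholder (0.5, 1.0), and normalize_tensor applies the usual (x - mean) / std transform. A quick illustrative sketch (not part of the commit):

    import torch
    from tests.helpers.datasets import MNIST

    img = torch.rand(1, 28, 28)
    out = MNIST.normalize_tensor(img, mean=0.1307, std=0.3081)  # computes (img - mean) / std
    assert torch.allclose(out, (img - 0.1307) / 0.3081)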
@@ -177,25 +160,15 @@ class TrialMNIST(MNIST):
         tensor([100, 100, 100])
     """

-    def __init__(
-        self,
-        root: str = PATH_DATASETS,
-        train: bool = True,
-        normalize: tuple = (0.5, 1.0),
-        download: bool = False,
-        num_samples: int = 100,
-        digits: Optional[Sequence] = (0, 1, 2),
-    ):
+    def __init__(self, num_samples: int = 100, digits: Optional[Sequence] = (0, 1, 2), **kwargs):
         # number of examples per class
         self.num_samples = num_samples
         # take just a subset of MNIST dataset
-        self.digits = digits if digits else list(range(10))
+        self.digits = sorted(digits) if digits else list(range(10))

-        self.cache_folder_name = 'digits-' + '-'.join(str(d) for d in sorted(self.digits)) \
-                                 + f'_nb-{self.num_samples}'
+        self.cache_folder_name = f"digits-{'-'.join(str(d) for d in self.digits)}_nb-{self.num_samples}"

-        super().__init__(root, train=train, normalize=normalize, download=download)
+        super().__init__(normalize=(0.5, 1.0), **kwargs)

     @staticmethod
     def _prepare_subset(full_data: torch.Tensor, full_targets: torch.Tensor, num_samples: int, digits: Sequence):
@@ -213,16 +186,12 @@ class TrialMNIST(MNIST):
         targets = full_targets[indexes]
         return data, targets

-    def prepare_data(self, download: bool) -> None:
-        if self._check_exists(self.cached_folder_path):
-            return
-        if download:
-            self._download(super().cached_folder_path)
-
     def _download(self, data_folder: str) -> None:
         super()._download(data_folder)
         for fname in (self.TRAIN_FILE_NAME, self.TEST_FILE_NAME):
-            path_fname = os.path.join(super().cached_folder_path, fname)
-            assert os.path.isfile(path_fname), 'Missing cached file: %s' % path_fname
-            data, targets = _try_load(path_fname)
+            path_fname = os.path.join(self.cached_folder_path, fname)
+            assert os.path.isfile(path_fname), f'Missing cached file: {path_fname}'
+            data, targets = self._try_load(path_fname)
             data, targets = self._prepare_subset(data, targets, self.num_samples, self.digits)
             torch.save((data, targets), os.path.join(self.cached_folder_path, fname))
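
Since TrialMNIST now accepts **kwargs and forwards them to MNIST, dataset options such as root, train and download no longer need to be re-declared on the subclass. An illustrative sketch of the new constructor (not part of the commit):

    from tests.helpers.datasets import TrialMNIST

    # 64 examples for each of the digits 0, 1 and 5; root/train/download pass through **kwargs
    dataset = TrialMNIST(num_samples=64, digits=(0, 1, 5), root="./", train=True, download=True)
    img, target = dataset[0]  # img: 1x28x28 float tensor, target: int class label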

tests/models/test_torchscript.py

@@ -18,7 +18,7 @@ import torch
 from tests.helpers import BoringModel
 from tests.helpers.advanced_models import BasicGAN, ParityModuleRNN
-from tests.helpers.datamodules import TrialMNISTDataModule
+from tests.helpers.datamodules import MNISTDataModule


 @pytest.mark.parametrize("modelclass", [
@@ -116,7 +116,7 @@ def test_torchscript_retain_training_state():
 def test_torchscript_properties(tmpdir, modelclass):
     """ Test that scripted LightningModule has unnecessary methods removed. """
     model = modelclass()
-    model.datamodule = TrialMNISTDataModule(tmpdir)
+    model.datamodule = MNISTDataModule(tmpdir)
     script = model.to_torchscript()
     assert not hasattr(script, "datamodule")
     assert not hasattr(model, "batch_size") or hasattr(script, "batch_size")

tests/trainer/test_lr_finder.py

@@ -180,7 +180,6 @@ def test_call_to_trainer_method(tmpdir, optimizer):
 def test_datamodule_parameter(tmpdir):
     """ Test that the datamodule parameter works """
-    # trial datamodule
     dm = ClassifDataModule()
     model = ClassificationModel()