Delete tests.helpers.TrialMNISTDataModule (#5999)
* Remove TrialMNISTDataModule
* Allow using TrialMNIST in the MNISTDataModule
* Update tests/helpers/datasets.py
parent d2cd7cb0f9
commit bfcfac4614
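A minimal usage sketch of the datamodule after this change (illustrative, not part of the commit): the new use_trials flag swaps the full MNIST dataset for the constrained TrialMNIST subset.

from tests.helpers.datamodules import MNISTDataModule

dm = MNISTDataModule(data_dir="./", batch_size=32, use_trials=True)
dm.prepare_data()                     # downloads (or builds) the cached dataset files
dm.setup("fit")                       # assigns dm.mnist_train
train_loader = dm.train_dataloader()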
@@ -11,68 +11,32 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Dict, Optional
+from typing import Optional

 import torch
-from sklearn.datasets import make_classification, make_regression
-from sklearn.model_selection import train_test_split
-from torch.utils.data import DataLoader, random_split
-from torch.utils.data.distributed import DistributedSampler
+from torch.utils.data import DataLoader

 from pytorch_lightning.core.datamodule import LightningDataModule
+from pytorch_lightning.utilities import _module_available
 from tests.helpers.datasets import MNIST, SklearnDataset, TrialMNIST


-class TrialMNISTDataModule(LightningDataModule):
-
-    def __init__(self, data_dir: str = "./"):
-        super().__init__()
-        self.data_dir = data_dir
-        self.non_picklable = None
-        self.checkpoint_state: Optional[str] = None
-
-    def prepare_data(self):
-        TrialMNIST(self.data_dir, train=True, download=True)
-        TrialMNIST(self.data_dir, train=False, download=True)
-
-    def setup(self, stage: Optional[str] = None):
-
-        if stage == "fit" or stage is None:
-            mnist_full = TrialMNIST(root=self.data_dir, train=True, num_samples=64, download=True)
-            self.mnist_train, self.mnist_val = random_split(mnist_full, [128, 64])
-            self.dims = self.mnist_train[0][0].shape
-
-        if stage == "test" or stage is None:
-            self.mnist_test = TrialMNIST(root=self.data_dir, train=False, num_samples=64, download=True)
-            self.dims = getattr(self, "dims", self.mnist_test[0][0].shape)
-
-        self.non_picklable = lambda x: x**2
-
-    def train_dataloader(self):
-        return DataLoader(self.mnist_train, batch_size=32)
-
-    def val_dataloader(self):
-        return DataLoader(self.mnist_val, batch_size=32)
-
-    def test_dataloader(self):
-        return DataLoader(self.mnist_test, batch_size=32)
-
-    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
-        checkpoint[self.__class__.__name__] = self.__class__.__name__
-
-    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
-        self.checkpoint_state = checkpoint.get(self.__class__.__name__)
+_SKLEARN_AVAILABLE = _module_available("sklearn")
+if _SKLEARN_AVAILABLE:
+    from sklearn.datasets import make_classification, make_regression
+    from sklearn.model_selection import train_test_split


 class MNISTDataModule(LightningDataModule):

-    def __init__(self, data_dir: str = "./", batch_size: int = 32, dist_sampler: bool = False) -> None:
+    def __init__(self, data_dir: str = "./", batch_size: int = 32, use_trials: bool = False) -> None:
         super().__init__()

-        self.dist_sampler = dist_sampler
         self.data_dir = data_dir
         self.batch_size = batch_size

+        # TrialMNIST is a constrained MNIST dataset
+        self.dataset_cls = TrialMNIST if use_trials else MNIST
+
         # self.dims is returned when you call dm.size()
         # Setting default dims here because we know them.
         # Could optionally be assigned dynamically in dm.setup()
@@ -80,31 +44,18 @@ class MNISTDataModule(LightningDataModule):

     def prepare_data(self):
         # download only
-        MNIST(self.data_dir, train=True, download=True, normalize=(0.1307, 0.3081))
-        MNIST(self.data_dir, train=False, download=True, normalize=(0.1307, 0.3081))
+        self.dataset_cls(self.data_dir, train=True, download=True)
+        self.dataset_cls(self.data_dir, train=False, download=True)

     def setup(self, stage: Optional[str] = None):

         # Assign train/val datasets for use in dataloaders
         # TODO: need to split using random_split once updated to torch >= 1.6
         if stage == "fit" or stage is None:
-            self.mnist_train = MNIST(self.data_dir, train=True, normalize=(0.1307, 0.3081))
+            self.mnist_train = self.dataset_cls(self.data_dir, train=True)

         # Assign test dataset for use in dataloader(s)
         if stage == "test" or stage is None:
-            self.mnist_test = MNIST(self.data_dir, train=False, normalize=(0.1307, 0.3081))
+            self.mnist_test = self.dataset_cls(self.data_dir, train=False)

     def train_dataloader(self):
-        dist_sampler = None
-        if self.dist_sampler:
-            dist_sampler = DistributedSampler(self.mnist_train, shuffle=False)
-
-        return DataLoader(
-            self.mnist_train,
-            batch_size=self.batch_size,
-            sampler=dist_sampler,
-            shuffle=False,
-        )
+        return DataLoader(self.mnist_train, batch_size=self.batch_size, shuffle=False)

     def test_dataloader(self):
         return DataLoader(self.mnist_test, batch_size=self.batch_size, shuffle=False)
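The dist_sampler flag is removed with no in-module replacement; a caller that still wants distributed sampling would now wrap the dataset externally. A sketch under that assumption (make_ddp_loader is a hypothetical helper name, not part of the commit):

from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def make_ddp_loader(dm, batch_size=32):
    # hypothetical helper: assumes dm.setup("fit") has already run and that
    # the default process group is initialized (DistributedSampler needs it)
    sampler = DistributedSampler(dm.mnist_train, shuffle=False)
    return DataLoader(dm.mnist_train, batch_size=batch_size, sampler=sampler, shuffle=False)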
@@ -67,7 +67,7 @@ class MNIST(Dataset):
         self,
         root: str = PATH_DATASETS,
         train: bool = True,
-        normalize: tuple = (0.5, 1.0),
+        normalize: tuple = (0.1307, 0.3081),
         download: bool = True,
     ):
         super().__init__()

@@ -77,18 +77,15 @@ class MNIST(Dataset):

         self.prepare_data(download)

-        if not self._check_exists(self.cached_folder_path):
-            raise RuntimeError('Dataset not found.')
-
         data_file = self.TRAIN_FILE_NAME if self.train else self.TEST_FILE_NAME
-        self.data, self.targets = _try_load(os.path.join(self.cached_folder_path, data_file))
+        self.data, self.targets = self._try_load(os.path.join(self.cached_folder_path, data_file))

     def __getitem__(self, idx: int) -> Tuple[Tensor, int]:
         img = self.data[idx].float().unsqueeze(0)
         target = int(self.targets[idx])

-        if self.normalize is not None:
-            img = normalize_tensor(img, mean=self.normalize[0], std=self.normalize[1])
+        if self.normalize is not None and len(self.normalize) == 2:
+            img = self.normalize_tensor(img, *self.normalize)

         return img, target

@@ -105,67 +102,53 @@ class MNIST(Dataset):
             existing = existing and os.path.isfile(os.path.join(data_folder, fname))
         return existing

-    def prepare_data(self, download: bool):
-        if download:
+    def prepare_data(self, download: bool = True):
+        if download and not self._check_exists(self.cached_folder_path):
             self._download(self.cached_folder_path)
+        if not self._check_exists(self.cached_folder_path):
+            raise RuntimeError('Dataset not found.')

     def _download(self, data_folder: str) -> None:
         """Download the MNIST data if it doesn't exist in cached_folder_path already."""
-
-        if self._check_exists(data_folder):
-            return
-
-        os.makedirs(data_folder, exist_ok=True)
-
+        os.makedirs(data_folder)
         for url in self.RESOURCES:
             logging.info(f'Downloading {url}')
             fpath = os.path.join(data_folder, os.path.basename(url))
             urllib.request.urlretrieve(url, fpath)

+    @staticmethod
+    def _try_load(path_data, trials: int = 30, delta: float = 1.):
+        """Resolving loading from the same time from multiple concurrent processes."""
+        res, exception = None, None
+        assert trials, "at least some trial has to be set"
+        assert os.path.isfile(path_data), f'missing file: {path_data}'
+        for _ in range(trials):
+            try:
+                res = torch.load(path_data)
+            # todo: specify the possible exception
+            except Exception as e:
+                exception = e
+                time.sleep(delta * random.random())
+            else:
+                break
+        if exception is not None:
+            # raise the caught exception
+            raise exception
+        return res

-def _try_load(path_data, trials: int = 30, delta: float = 1.):
-    """Resolving loading from the same time from multiple concurrent processes."""
-    res, exp = None, None
-    assert trials, "at least some trial has to be set"
-    assert os.path.isfile(path_data), 'missing file: %s' % path_data
-    for _ in range(trials):
-        try:
-            res = torch.load(path_data)
-        # todo: specify the possible exception
-        except Exception as ex:
-            exp = ex
-            time.sleep(delta * random.random())
-        else:
-            break
-    else:
-        # raise the caught exception if any
-        if exp:
-            raise exp
-    return res
-
-
-def normalize_tensor(tensor: Tensor, mean: float = 0.0, std: float = 1.0) -> Tensor:
-    tensor = tensor.clone()
-    mean = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device)
-    std = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device)
-    tensor.sub_(mean).div_(std)
-    return tensor
+    @staticmethod
+    def normalize_tensor(tensor: Tensor, mean: float = 0.0, std: float = 1.0) -> Tensor:
+        mean = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device)
+        std = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device)
+        return tensor.sub(mean).div(std)


 class TrialMNIST(MNIST):
-    """Constrain image dataset
+    """Constrained MNIST dataset

     Args:
-        root: Root directory of dataset where ``MNIST/processed/training.pt``
-            and ``MNIST/processed/test.pt`` exist.
-        train: If ``True``, creates dataset from ``training.pt``,
-            otherwise from ``test.pt``.
-        normalize: mean and std deviation of the MNIST dataset.
-        download: If true, downloads the dataset from the internet and
-            puts it in root directory. If dataset is already downloaded, it is not
-            downloaded again.
         num_samples: number of examples per selected class/digit
         digits: list selected MNIST digits/classes
+        kwargs: Same as MNIST

     Examples:
         >>> dataset = TrialMNIST(download=True)

@@ -177,25 +160,15 @@ class TrialMNIST(MNIST):
         tensor([100, 100, 100])
     """

-    def __init__(
-        self,
-        root: str = PATH_DATASETS,
-        train: bool = True,
-        normalize: tuple = (0.5, 1.0),
-        download: bool = False,
-        num_samples: int = 100,
-        digits: Optional[Sequence] = (0, 1, 2),
-    ):
+    def __init__(self, num_samples: int = 100, digits: Optional[Sequence] = (0, 1, 2), **kwargs):
         # number of examples per class
         self.num_samples = num_samples
         # take just a subset of MNIST dataset
-        self.digits = digits if digits else list(range(10))
+        self.digits = sorted(digits) if digits else list(range(10))

-        self.cache_folder_name = 'digits-' + '-'.join(str(d) for d in sorted(self.digits)) \
-            + f'_nb-{self.num_samples}'
+        self.cache_folder_name = f"digits-{'-'.join(str(d) for d in self.digits)}_nb-{self.num_samples}"

-        super().__init__(root, train=train, normalize=normalize, download=download)
+        super().__init__(normalize=(0.5, 1.0), **kwargs)

     @staticmethod
     def _prepare_subset(full_data: torch.Tensor, full_targets: torch.Tensor, num_samples: int, digits: Sequence):

@@ -213,16 +186,12 @@ class TrialMNIST(MNIST):
         targets = full_targets[indexes]
         return data, targets

-    def prepare_data(self, download: bool) -> None:
-        if self._check_exists(self.cached_folder_path):
-            return
-        if download:
-            self._download(super().cached_folder_path)
-
     def _download(self, data_folder: str) -> None:
         super()._download(data_folder)
         for fname in (self.TRAIN_FILE_NAME, self.TEST_FILE_NAME):
-            path_fname = os.path.join(super().cached_folder_path, fname)
-            assert os.path.isfile(path_fname), 'Missing cached file: %s' % path_fname
-            data, targets = _try_load(path_fname)
+            path_fname = os.path.join(self.cached_folder_path, fname)
+            assert os.path.isfile(path_fname), f'Missing cached file: {path_fname}'
+            data, targets = self._try_load(path_fname)
             data, targets = self._prepare_subset(data, targets, self.num_samples, self.digits)
             torch.save((data, targets), os.path.join(self.cached_folder_path, fname))
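With the **kwargs refactor, TrialMNIST keeps only its own options (num_samples, digits) and forwards everything else to MNIST; it also pins normalize=(0.5, 1.0) internally. A sketch of the new call shape (the digit choice is illustrative):

from tests.helpers.datasets import TrialMNIST

# num_samples/digits are consumed by TrialMNIST; root/train/download pass
# through to MNIST via **kwargs
dataset = TrialMNIST(root="./datasets", download=True, num_samples=50, digits=(3, 7))
assert sorted({int(t) for t in dataset.targets}) == [3, 7]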
@@ -18,7 +18,7 @@ import torch

 from tests.helpers import BoringModel
 from tests.helpers.advanced_models import BasicGAN, ParityModuleRNN
-from tests.helpers.datamodules import TrialMNISTDataModule
+from tests.helpers.datamodules import MNISTDataModule


 @pytest.mark.parametrize("modelclass", [

@@ -116,7 +116,7 @@ def test_torchscript_retain_training_state():
 def test_torchscript_properties(tmpdir, modelclass):
     """ Test that scripted LightningModule has unnecessary methods removed. """
     model = modelclass()
-    model.datamodule = TrialMNISTDataModule(tmpdir)
+    model.datamodule = MNISTDataModule(tmpdir)
     script = model.to_torchscript()
     assert not hasattr(script, "datamodule")
     assert not hasattr(model, "batch_size") or hasattr(script, "batch_size")
@@ -180,7 +180,6 @@ def test_call_to_trainer_method(tmpdir, optimizer):
 def test_datamodule_parameter(tmpdir):
     """ Test that the datamodule parameter works """

-    # trial datamodule
     dm = ClassifDataModule()
     model = ClassificationModel()