Set the generator seed in `random_split` usages in the examples (#18651)

This commit is contained in:
Adrian Wälchli 2023-09-27 05:09:05 -07:00 committed by GitHub
parent 4628dfe6fe
commit c631726b5c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 23 additions and 8 deletions

View File

@ -89,7 +89,9 @@ The equivalent DataModule just organizes the same exact code, but makes it reusa
self.mnist_test = MNIST(self.data_dir, train=False) self.mnist_test = MNIST(self.data_dir, train=False)
self.mnist_predict = MNIST(self.data_dir, train=False) self.mnist_predict = MNIST(self.data_dir, train=False)
mnist_full = MNIST(self.data_dir, train=True) mnist_full = MNIST(self.data_dir, train=True)
self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000]) self.mnist_train, self.mnist_val = random_split(
mnist_full, [55000, 5000], generator=torch.Generator().manual_seed(42)
)
def train_dataloader(self): def train_dataloader(self):
return DataLoader(self.mnist_train, batch_size=self.batch_size) return DataLoader(self.mnist_train, batch_size=self.batch_size)
@ -146,7 +148,9 @@ Here's a more realistic, complex DataModule that shows how much more reusable th
# Assign train/val datasets for use in dataloaders # Assign train/val datasets for use in dataloaders
if stage == "fit": if stage == "fit":
mnist_full = MNIST(self.data_dir, train=True, transform=self.transform) mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000]) self.mnist_train, self.mnist_val = random_split(
mnist_full, [55000, 5000], generator=torch.Generator().manual_seed(42)
)
# Assign test dataset for use in dataloader(s) # Assign test dataset for use in dataloader(s)
if stage == "test": if stage == "test":
@ -230,7 +234,9 @@ There are also data operations you might want to perform on every GPU. Use :meth
# Assign Train/val split(s) for use in Dataloaders # Assign Train/val split(s) for use in Dataloaders
if stage == "fit": if stage == "fit":
mnist_full = MNIST(self.data_dir, train=True, download=True, transform=self.transform) mnist_full = MNIST(self.data_dir, train=True, download=True, transform=self.transform)
self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000]) self.mnist_train, self.mnist_val = random_split(
mnist_full, [55000, 5000], generator=torch.Generator().manual_seed(42)
)
# Assign Test split(s) for use in Dataloaders # Assign Test split(s) for use in Dataloaders
if stage == "test": if stage == "test":

View File

@ -63,7 +63,8 @@ def validate(fabric, model, val_dataloader):
def get_dataloaders(dataset): def get_dataloaders(dataset):
n = len(dataset) n = len(dataset)
train_dataset, val_dataset, test_dataset = random_split(dataset, [n - 4000, 2000, 2000]) generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset, test_dataset = random_split(dataset, [n - 4000, 2000, 2000], generator=generator)
train_dataloader = DataLoader(train_dataset, batch_size=20, shuffle=True) train_dataloader = DataLoader(train_dataset, batch_size=20, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=20, shuffle=False) val_dataloader = DataLoader(val_dataset, batch_size=20, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=20, shuffle=False) test_dataloader = DataLoader(test_dataset, batch_size=20, shuffle=False)

View File

@ -156,7 +156,9 @@ class MyDataModule(LightningDataModule):
super().__init__() super().__init__()
dataset = MNIST(DATASETS_PATH, train=True, download=True, transform=transforms.ToTensor()) dataset = MNIST(DATASETS_PATH, train=True, download=True, transform=transforms.ToTensor())
self.mnist_test = MNIST(DATASETS_PATH, train=False, download=True, transform=transforms.ToTensor()) self.mnist_test = MNIST(DATASETS_PATH, train=False, download=True, transform=transforms.ToTensor())
self.mnist_train, self.mnist_val = random_split(dataset, [55000, 5000]) self.mnist_train, self.mnist_val = random_split(
dataset, [55000, 5000], generator=torch.Generator().manual_seed(42)
)
self.batch_size = batch_size self.batch_size = batch_size
def train_dataloader(self): def train_dataloader(self):

View File

@ -105,7 +105,9 @@ class MyDataModule(LightningDataModule):
super().__init__() super().__init__()
dataset = MNIST(DATASETS_PATH, train=True, download=True, transform=transforms.ToTensor()) dataset = MNIST(DATASETS_PATH, train=True, download=True, transform=transforms.ToTensor())
self.mnist_test = MNIST(DATASETS_PATH, train=False, download=True, transform=transforms.ToTensor()) self.mnist_test = MNIST(DATASETS_PATH, train=False, download=True, transform=transforms.ToTensor())
self.mnist_train, self.mnist_val = random_split(dataset, [55000, 5000]) self.mnist_train, self.mnist_val = random_split(
dataset, [55000, 5000], generator=torch.Generator().manual_seed(42)
)
self.batch_size = batch_size self.batch_size = batch_size
def train_dataloader(self): def train_dataloader(self):

View File

@ -48,7 +48,9 @@ class LightningDataModule(DataHooks, HyperparametersMixin):
# make assignments here (val/train/test split) # make assignments here (val/train/test split)
# called on every process in DDP # called on every process in DDP
dataset = RandomDataset(1, 100) dataset = RandomDataset(1, 100)
self.train, self.val, self.test = data.random_split(dataset, [80, 10, 10]) self.train, self.val, self.test = data.random_split(
dataset, [80, 10, 10], generator=torch.Generator().manual_seed(42)
)
def train_dataloader(self): def train_dataloader(self):
return data.DataLoader(self.train) return data.DataLoader(self.train)

View File

@ -200,7 +200,9 @@ class MNISTDataModule(LightningDataModule):
dataset: Dataset = MNIST(self.data_dir, train=True, download=False, **extra) dataset: Dataset = MNIST(self.data_dir, train=True, download=False, **extra)
assert isinstance(dataset, Sized) assert isinstance(dataset, Sized)
train_length = len(dataset) train_length = len(dataset)
self.dataset_train, self.dataset_val = random_split(dataset, [train_length - self.val_split, self.val_split]) self.dataset_train, self.dataset_val = random_split(
dataset, [train_length - self.val_split, self.val_split], generator=torch.Generator().manual_seed(42)
)
def train_dataloader(self) -> DataLoader: def train_dataloader(self) -> DataLoader:
"""MNIST train set removes a subset to use for validation.""" """MNIST train set removes a subset to use for validation."""