Set the generator seed in `random_split` usages in the examples (#18651)
This commit is contained in:
parent
4628dfe6fe
commit
c631726b5c
|
@ -89,7 +89,9 @@ The equivalent DataModule just organizes the same exact code, but makes it reusa
|
|||
self.mnist_test = MNIST(self.data_dir, train=False)
|
||||
self.mnist_predict = MNIST(self.data_dir, train=False)
|
||||
mnist_full = MNIST(self.data_dir, train=True)
|
||||
self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
|
||||
self.mnist_train, self.mnist_val = random_split(
|
||||
mnist_full, [55000, 5000], generator=torch.Generator().manual_seed(42)
|
||||
)
|
||||
|
||||
def train_dataloader(self):
|
||||
return DataLoader(self.mnist_train, batch_size=self.batch_size)
|
||||
|
@ -146,7 +148,9 @@ Here's a more realistic, complex DataModule that shows how much more reusable th
|
|||
# Assign train/val datasets for use in dataloaders
|
||||
if stage == "fit":
|
||||
mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
|
||||
self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
|
||||
self.mnist_train, self.mnist_val = random_split(
|
||||
mnist_full, [55000, 5000], generator=torch.Generator().manual_seed(42)
|
||||
)
|
||||
|
||||
# Assign test dataset for use in dataloader(s)
|
||||
if stage == "test":
|
||||
|
@ -230,7 +234,9 @@ There are also data operations you might want to perform on every GPU. Use :meth
|
|||
# Assign Train/val split(s) for use in Dataloaders
|
||||
if stage == "fit":
|
||||
mnist_full = MNIST(self.data_dir, train=True, download=True, transform=self.transform)
|
||||
self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
|
||||
self.mnist_train, self.mnist_val = random_split(
|
||||
mnist_full, [55000, 5000], generator=torch.Generator().manual_seed(42)
|
||||
)
|
||||
|
||||
# Assign Test split(s) for use in Dataloaders
|
||||
if stage == "test":
|
||||
|
|
|
@ -63,7 +63,8 @@ def validate(fabric, model, val_dataloader):
|
|||
|
||||
def get_dataloaders(dataset):
|
||||
n = len(dataset)
|
||||
train_dataset, val_dataset, test_dataset = random_split(dataset, [n - 4000, 2000, 2000])
|
||||
generator = torch.Generator().manual_seed(42)
|
||||
train_dataset, val_dataset, test_dataset = random_split(dataset, [n - 4000, 2000, 2000], generator=generator)
|
||||
train_dataloader = DataLoader(train_dataset, batch_size=20, shuffle=True)
|
||||
val_dataloader = DataLoader(val_dataset, batch_size=20, shuffle=False)
|
||||
test_dataloader = DataLoader(test_dataset, batch_size=20, shuffle=False)
|
||||
|
|
|
@ -156,7 +156,9 @@ class MyDataModule(LightningDataModule):
|
|||
super().__init__()
|
||||
dataset = MNIST(DATASETS_PATH, train=True, download=True, transform=transforms.ToTensor())
|
||||
self.mnist_test = MNIST(DATASETS_PATH, train=False, download=True, transform=transforms.ToTensor())
|
||||
self.mnist_train, self.mnist_val = random_split(dataset, [55000, 5000])
|
||||
self.mnist_train, self.mnist_val = random_split(
|
||||
dataset, [55000, 5000], generator=torch.Generator().manual_seed(42)
|
||||
)
|
||||
self.batch_size = batch_size
|
||||
|
||||
def train_dataloader(self):
|
||||
|
|
|
@ -105,7 +105,9 @@ class MyDataModule(LightningDataModule):
|
|||
super().__init__()
|
||||
dataset = MNIST(DATASETS_PATH, train=True, download=True, transform=transforms.ToTensor())
|
||||
self.mnist_test = MNIST(DATASETS_PATH, train=False, download=True, transform=transforms.ToTensor())
|
||||
self.mnist_train, self.mnist_val = random_split(dataset, [55000, 5000])
|
||||
self.mnist_train, self.mnist_val = random_split(
|
||||
dataset, [55000, 5000], generator=torch.Generator().manual_seed(42)
|
||||
)
|
||||
self.batch_size = batch_size
|
||||
|
||||
def train_dataloader(self):
|
||||
|
|
|
@ -48,7 +48,9 @@ class LightningDataModule(DataHooks, HyperparametersMixin):
|
|||
# make assignments here (val/train/test split)
|
||||
# called on every process in DDP
|
||||
dataset = RandomDataset(1, 100)
|
||||
self.train, self.val, self.test = data.random_split(dataset, [80, 10, 10])
|
||||
self.train, self.val, self.test = data.random_split(
|
||||
dataset, [80, 10, 10], generator=torch.Generator().manual_seed(42)
|
||||
)
|
||||
|
||||
def train_dataloader(self):
|
||||
return data.DataLoader(self.train)
|
||||
|
|
|
@ -200,7 +200,9 @@ class MNISTDataModule(LightningDataModule):
|
|||
dataset: Dataset = MNIST(self.data_dir, train=True, download=False, **extra)
|
||||
assert isinstance(dataset, Sized)
|
||||
train_length = len(dataset)
|
||||
self.dataset_train, self.dataset_val = random_split(dataset, [train_length - self.val_split, self.val_split])
|
||||
self.dataset_train, self.dataset_val = random_split(
|
||||
dataset, [train_length - self.val_split, self.val_split], generator=torch.Generator().manual_seed(42)
|
||||
)
|
||||
|
||||
def train_dataloader(self) -> DataLoader:
|
||||
"""MNIST train set removes a subset to use for validation."""
|
||||
|
|
Loading…
Reference in New Issue