From c631726b5cb8886eeec425ab9cdf0d352c8e68ad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 27 Sep 2023 05:09:05 -0700 Subject: [PATCH] Set the generator seed in `random_split` usages in the examples (#18651) --- docs/source-pytorch/data/datamodule.rst | 12 +++++++++--- examples/fabric/language_model/train.py | 3 ++- examples/pytorch/basics/autoencoder.py | 4 +++- examples/pytorch/basics/backbone_image_classifier.py | 4 +++- src/lightning/pytorch/core/datamodule.py | 4 +++- src/lightning/pytorch/demos/mnist_datamodule.py | 4 +++- 6 files changed, 23 insertions(+), 8 deletions(-) diff --git a/docs/source-pytorch/data/datamodule.rst b/docs/source-pytorch/data/datamodule.rst index b6bb1c5467..44cbc81254 100644 --- a/docs/source-pytorch/data/datamodule.rst +++ b/docs/source-pytorch/data/datamodule.rst @@ -89,7 +89,9 @@ The equivalent DataModule just organizes the same exact code, but makes it reusa self.mnist_test = MNIST(self.data_dir, train=False) self.mnist_predict = MNIST(self.data_dir, train=False) mnist_full = MNIST(self.data_dir, train=True) - self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000]) + self.mnist_train, self.mnist_val = random_split( + mnist_full, [55000, 5000], generator=torch.Generator().manual_seed(42) + ) def train_dataloader(self): return DataLoader(self.mnist_train, batch_size=self.batch_size) @@ -146,7 +148,9 @@ Here's a more realistic, complex DataModule that shows how much more reusable th # Assign train/val datasets for use in dataloaders if stage == "fit": mnist_full = MNIST(self.data_dir, train=True, transform=self.transform) - self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000]) + self.mnist_train, self.mnist_val = random_split( + mnist_full, [55000, 5000], generator=torch.Generator().manual_seed(42) + ) # Assign test dataset for use in dataloader(s) if stage == "test": @@ -230,7 +234,9 @@ There are also data operations you might want to perform on every GPU. Use :meth # Assign Train/val split(s) for use in Dataloaders if stage == "fit": mnist_full = MNIST(self.data_dir, train=True, download=True, transform=self.transform) - self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000]) + self.mnist_train, self.mnist_val = random_split( + mnist_full, [55000, 5000], generator=torch.Generator().manual_seed(42) + ) # Assign Test split(s) for use in Dataloaders if stage == "test": diff --git a/examples/fabric/language_model/train.py b/examples/fabric/language_model/train.py index 9dd4efa628..cafe6ceeb1 100644 --- a/examples/fabric/language_model/train.py +++ b/examples/fabric/language_model/train.py @@ -63,7 +63,8 @@ def validate(fabric, model, val_dataloader): def get_dataloaders(dataset): n = len(dataset) - train_dataset, val_dataset, test_dataset = random_split(dataset, [n - 4000, 2000, 2000]) + generator = torch.Generator().manual_seed(42) + train_dataset, val_dataset, test_dataset = random_split(dataset, [n - 4000, 2000, 2000], generator=generator) train_dataloader = DataLoader(train_dataset, batch_size=20, shuffle=True) val_dataloader = DataLoader(val_dataset, batch_size=20, shuffle=False) test_dataloader = DataLoader(test_dataset, batch_size=20, shuffle=False) diff --git a/examples/pytorch/basics/autoencoder.py b/examples/pytorch/basics/autoencoder.py index 0224f0007d..b31a7392f5 100644 --- a/examples/pytorch/basics/autoencoder.py +++ b/examples/pytorch/basics/autoencoder.py @@ -156,7 +156,9 @@ class MyDataModule(LightningDataModule): super().__init__() dataset = MNIST(DATASETS_PATH, train=True, download=True, transform=transforms.ToTensor()) self.mnist_test = MNIST(DATASETS_PATH, train=False, download=True, transform=transforms.ToTensor()) - self.mnist_train, self.mnist_val = random_split(dataset, [55000, 5000]) + self.mnist_train, self.mnist_val = random_split( + dataset, [55000, 5000], generator=torch.Generator().manual_seed(42) + ) self.batch_size = batch_size def train_dataloader(self): diff --git a/examples/pytorch/basics/backbone_image_classifier.py b/examples/pytorch/basics/backbone_image_classifier.py index f6a4dd30a0..ad58f767b8 100644 --- a/examples/pytorch/basics/backbone_image_classifier.py +++ b/examples/pytorch/basics/backbone_image_classifier.py @@ -105,7 +105,9 @@ class MyDataModule(LightningDataModule): super().__init__() dataset = MNIST(DATASETS_PATH, train=True, download=True, transform=transforms.ToTensor()) self.mnist_test = MNIST(DATASETS_PATH, train=False, download=True, transform=transforms.ToTensor()) - self.mnist_train, self.mnist_val = random_split(dataset, [55000, 5000]) + self.mnist_train, self.mnist_val = random_split( + dataset, [55000, 5000], generator=torch.Generator().manual_seed(42) + ) self.batch_size = batch_size def train_dataloader(self): diff --git a/src/lightning/pytorch/core/datamodule.py b/src/lightning/pytorch/core/datamodule.py index 32913c6015..afe2f0b110 100644 --- a/src/lightning/pytorch/core/datamodule.py +++ b/src/lightning/pytorch/core/datamodule.py @@ -48,7 +48,9 @@ class LightningDataModule(DataHooks, HyperparametersMixin): # make assignments here (val/train/test split) # called on every process in DDP dataset = RandomDataset(1, 100) - self.train, self.val, self.test = data.random_split(dataset, [80, 10, 10]) + self.train, self.val, self.test = data.random_split( + dataset, [80, 10, 10], generator=torch.Generator().manual_seed(42) + ) def train_dataloader(self): return data.DataLoader(self.train) diff --git a/src/lightning/pytorch/demos/mnist_datamodule.py b/src/lightning/pytorch/demos/mnist_datamodule.py index fe500cc665..7c9050dcd5 100644 --- a/src/lightning/pytorch/demos/mnist_datamodule.py +++ b/src/lightning/pytorch/demos/mnist_datamodule.py @@ -200,7 +200,9 @@ class MNISTDataModule(LightningDataModule): dataset: Dataset = MNIST(self.data_dir, train=True, download=False, **extra) assert isinstance(dataset, Sized) train_length = len(dataset) - self.dataset_train, self.dataset_val = random_split(dataset, [train_length - self.val_split, self.val_split]) + self.dataset_train, self.dataset_val = random_split( + dataset, [train_length - self.val_split, self.val_split], generator=torch.Generator().manual_seed(42) + ) def train_dataloader(self) -> DataLoader: """MNIST train set removes a subset to use for validation."""