diff --git a/CHANGELOG.md b/CHANGELOG.md
index 2f82cbe974..c6c38163ee 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -51,6 +51,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed getting `experiment_id` from MLFlow only once instead of each training loop ([#3394](https://github.com/PyTorchLightning/pytorch-lightning/pull/3394))
 
+- Fixed `overfit_batches` so that it correctly disables shuffling for the training loader ([#3501](https://github.com/PyTorchLightning/pytorch-lightning/pull/3501))
+
 ## [0.9.0] - YYYY-MM-DD
 
 ### Added
diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py
index f7c53c1cbe..fe2f942138 100644
--- a/pytorch_lightning/trainer/data_loading.py
+++ b/pytorch_lightning/trainer/data_loading.py
@@ -167,6 +167,12 @@ class TrainerDataLoadingMixin(ABC):
             model: The current `LightningModule`
         """
         self.train_dataloader = self.request_dataloader(model.train_dataloader)
+        if self.overfit_batches > 0:
+            if hasattr(self.train_dataloader, 'sampler') and isinstance(self.train_dataloader.sampler, RandomSampler):
+                rank_zero_warn('You requested to overfit but enabled training dataloader shuffling.'
+                               ' We are turning it off for you.')
+                self.train_dataloader = self.replace_sampler(
+                    self.train_dataloader, SequentialSampler(self.train_dataloader.dataset))
 
         # debugging
         self.dev_debugger.track_load_dataloader_call('train_dataloader', dataloaders=[self.train_dataloader])
@@ -247,7 +253,7 @@ class TrainerDataLoadingMixin(ABC):
 
                 # when overfitting, the dataloader should not have sampler
                 if self.overfit_batches > 0:
-                    rank_zero_warn('You requested to overfit but enabled training dataloader shuffling.'
+                    rank_zero_warn('You requested to overfit but enabled test/val dataloader shuffling.'
                                    ' We are turning it off for you.')
                     dataloaders[loader_i] = self.replace_sampler(loader, SequentialSampler(loader.dataset))
 
diff --git a/tests/trainer/test_trainer_tricks.py b/tests/trainer/test_trainer_tricks.py
index 95b7fd2067..b1716485ed 100755
--- a/tests/trainer/test_trainer_tricks.py
+++ b/tests/trainer/test_trainer_tricks.py
@@ -56,11 +56,13 @@ def test_overfit_batch_limits(tmpdir):
     # ------------------------------------------------------
     # get the training loader and batch
     # ------------------------------------------------------
+    # Create a reference train dataloader without shuffling.
     train_loader = DataLoader(model.train_dataloader().dataset, shuffle=False)
+    (xa, ya) = next(iter(train_loader))
+
     train_loader = DataLoader(model.train_dataloader().dataset, shuffle=True)
     full_train_samples = len(train_loader)
     num_train_samples = int(0.11 * full_train_samples)
-    (xa, ya) = next(iter(train_loader))
 
     # ------------------------------------------------------
     # set VAL and Test loaders
@@ -87,7 +89,8 @@ def test_overfit_batch_limits(tmpdir):
     trainer = Trainer(overfit_batches=0.11)
     trainer.reset_train_dataloader(model)
 
-    assert trainer.train_dataloader is train_loader
+    # The training dataloader should have been replaced with one that uses a SequentialSampler.
+    assert trainer.train_dataloader is not train_loader
     assert trainer.num_training_batches == num_train_samples
 
     # make sure the loaders are the same
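
Note on the fix above: a `DataLoader` cannot have its sampler swapped in place, so `replace_sampler` rebuilds the loader around a `SequentialSampler`. The standalone sketch below illustrates that core idea with a toy dataset; the dataset, variable names, and the particular keyword arguments copied over are illustrative assumptions, not code from this PR.

# Standalone sketch (not part of the PR): the gist of replacing a
# RandomSampler with a SequentialSampler when overfitting on a fixed subset.
import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset

# Toy dataset: 8 samples so the ordering is easy to inspect.
dataset = TensorDataset(torch.arange(8).float().unsqueeze(1))

# shuffle=True installs a RandomSampler under the hood, which is
# what the check in `reset_train_dataloader` detects.
shuffled = DataLoader(dataset, batch_size=2, shuffle=True)
assert isinstance(shuffled.sampler, RandomSampler)

# Rebuild the loader with a SequentialSampler, carrying over the other
# settings (the real `replace_sampler` copies more attributes than shown here).
deterministic = DataLoader(
    dataset,
    batch_size=shuffled.batch_size,
    sampler=SequentialSampler(dataset),
    num_workers=shuffled.num_workers,
    collate_fn=shuffled.collate_fn,
)

# Every epoch now yields batches in the same fixed order, which is the
# behavior `overfit_batches` relies on.
epoch_one = [batch[0].squeeze(1).tolist() for batch in deterministic]
epoch_two = [batch[0].squeeze(1).tolist() for batch in deterministic]
assert epoch_one == epoch_two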