Disable train dataloader shuffle when overfit_batches is active. (#3501)

* Disable train dataloader shuffle when overfit_batches is active.

* pep8

Co-authored-by: William Falcon <waf2107@columbia.edu>
Phil 2020-09-15 11:07:27 +02:00 committed by GitHub
parent 4dc4c8cfa5
commit b5dc6998ae
3 changed files with 14 additions and 3 deletions
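
Before the per-file diffs, a quick illustration of the bug being fixed: with `shuffle=True`, `overfit_batches` does not pin training to a fixed subset. A minimal sketch in plain PyTorch (not Lightning's internals; the dataset and loader names are invented for the example):

# Why a shuffled loader defeats overfit_batches.
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(100, dtype=torch.float32).unsqueeze(1))

# shuffle=True draws samples in a different order each epoch, so limiting
# training to the first few batches no longer fixes a single subset.
shuffled = DataLoader(dataset, batch_size=10, shuffle=True)

# Sequential order makes "the first N batches" identical across epochs,
# which is exactly what overfitting on a fixed subset requires.
sequential = DataLoader(dataset, batch_size=10, shuffle=False)

print(next(iter(shuffled))[0].flatten()[:3])    # varies from run to run
print(next(iter(sequential))[0].flatten()[:3])  # always tensor([0., 1., 2.])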


@@ -51,6 +51,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed getting `experiment_id` from MLFlow only once instead of each training loop ([#3394](https://github.com/PyTorchLightning/pytorch-lightning/pull/3394))
+- Fixed `overfit_batches` which now correctly disables shuffling for the training loader ([#3501](https://github.com/PyTorchLightning/pytorch-lightning/pull/3501))
+
 ## [0.9.0] - YYYY-MM-DD
 ### Added


@@ -167,6 +167,12 @@ class TrainerDataLoadingMixin(ABC):
             model: The current `LightningModule`
         """
         self.train_dataloader = self.request_dataloader(model.train_dataloader)
+        if self.overfit_batches > 0:
+            if hasattr(self.train_dataloader, 'sampler') and isinstance(self.train_dataloader.sampler, RandomSampler):
+                rank_zero_warn('You requested to overfit but enabled training dataloader shuffling.'
+                               ' We are turning it off for you.')
+                self.train_dataloader = self.replace_sampler(
+                    self.train_dataloader, SequentialSampler(self.train_dataloader.dataset))

         # debugging
         self.dev_debugger.track_load_dataloader_call('train_dataloader', dataloaders=[self.train_dataloader])
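
The `replace_sampler` call above rebuilds the train loader around a `SequentialSampler`. A simplified stand-in for that idea, assuming a vanilla `DataLoader` (`swap_to_sequential` is a hypothetical helper name; Lightning's real `replace_sampler` carries over more constructor arguments than shown here):

from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

def swap_to_sequential(loader: DataLoader) -> DataLoader:
    # Only rebuild when the loader is actually shuffled (RandomSampler).
    if not isinstance(loader.sampler, RandomSampler):
        return loader
    # Rebuild around a SequentialSampler; `shuffle` must stay unset because
    # passing both `sampler` and `shuffle=True` is an error in DataLoader.
    return DataLoader(
        loader.dataset,
        batch_size=loader.batch_size,
        sampler=SequentialSampler(loader.dataset),
        num_workers=loader.num_workers,
        collate_fn=loader.collate_fn,
        drop_last=loader.drop_last,
    )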
@@ -247,7 +253,7 @@ class TrainerDataLoadingMixin(ABC):
             # when overfitting, the dataloader should not have sampler
             if self.overfit_batches > 0:
-                rank_zero_warn('You requested to overfit but enabled training dataloader shuffling.'
+                rank_zero_warn('You requested to overfit but enabled test/val dataloader shuffling.'
                                ' We are turning it off for you.')
                 dataloaders[loader_i] = self.replace_sampler(loader, SequentialSampler(loader.dataset))
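
End to end, the behavior added by this commit can be exercised roughly as below. A hedged sketch against the 0.9-era API: `TinyModel` is an invented module, while `Trainer(overfit_batches=...)` and `reset_train_dataloader` are the same calls the PR's own test (next diff) uses:

import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset

class TinyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(1, 1)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        (x,) = batch
        return {'loss': self(x).pow(2).mean()}

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        # Shuffling is requested on purpose; overfit_batches should override it.
        return DataLoader(TensorDataset(torch.randn(100, 1)), batch_size=10, shuffle=True)

trainer = pl.Trainer(overfit_batches=0.11)
trainer.reset_train_dataloader(TinyModel())  # warns, then swaps the sampler

assert isinstance(trainer.train_dataloader.sampler, SequentialSampler)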


@@ -56,11 +56,13 @@ def test_overfit_batch_limits(tmpdir):
     # ------------------------------------------------------
     # get the training loader and batch
     # ------------------------------------------------------
+    # Create a reference train dataloader without shuffling.
     train_loader = DataLoader(model.train_dataloader().dataset, shuffle=False)
+    (xa, ya) = next(iter(train_loader))
+    train_loader = DataLoader(model.train_dataloader().dataset, shuffle=True)
     full_train_samples = len(train_loader)
     num_train_samples = int(0.11 * full_train_samples)
-    (xa, ya) = next(iter(train_loader))

     # ------------------------------------------------------
     # set VAL and Test loaders
@@ -87,7 +89,8 @@ def test_overfit_batch_limits(tmpdir):
     trainer = Trainer(overfit_batches=0.11)
     trainer.reset_train_dataloader(model)
-    assert trainer.train_dataloader is train_loader
+    # The dataloader should have been overwritten with a Sequential sampler.
+    assert trainer.train_dataloader is not train_loader
     assert trainer.num_training_batches == num_train_samples

     # make sure the loaders are the same
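
As a coda to the test's `(xa, ya)` comparison, a standalone sketch of the property being asserted: with a `SequentialSampler`, a fresh iterator yields the same first batch on every "epoch" (names invented for illustration):

import torch
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset

dataset = TensorDataset(torch.randn(64, 1))
loader = DataLoader(dataset, batch_size=8, sampler=SequentialSampler(dataset))

(batch_a,) = next(iter(loader))  # "epoch" 1
(batch_b,) = next(iter(loader))  # "epoch" 2
assert torch.equal(batch_a, batch_b)  # deterministic order -> identical batch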