Disable train dataloader shuffle when overfit_batches is active. (#3501)
* Disable train dataloader shuffle when overfit_batches is active.
* pep8

Co-authored-by: William Falcon <waf2107@columbia.edu>
This commit is contained in:
parent 4dc4c8cfa5
commit b5dc6998ae
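For context, the bug this commit fixes: with a shuffled sampler, `overfit_batches` trains on a different subset every epoch, because "the first N batches" are re-drawn each time a new iterator is created. A minimal standalone sketch (toy dataset, not from the commit) showing the difference:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    # Hypothetical toy dataset of 100 scalar samples.
    dataset = TensorDataset(torch.arange(100).float())

    # With shuffle=True the "first" batch changes on every fresh iterator,
    # so an overfit run would never see a fixed subset.
    shuffled = DataLoader(dataset, batch_size=4, shuffle=True)
    print(next(iter(shuffled))[0])  # different samples each time

    # With sequential sampling (what the fix enforces) the first N batches
    # are identical every epoch, which is what overfitting requires.
    sequential = DataLoader(dataset, batch_size=4, shuffle=False)
    print(next(iter(sequential))[0])  # always tensors 0., 1., 2., 3.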
CHANGELOG.md
@@ -51,6 +51,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed getting `experiment_id` from MLFlow only once instead of each training loop ([#3394](https://github.com/PyTorchLightning/pytorch-lightning/pull/3394))
+
+- Fixed `overfit_batches` which now correctly disables shuffling for the training loader ([#3501](https://github.com/PyTorchLightning/pytorch-lightning/pull/3501))
 
 ## [0.9.0] - YYYY-MM-DD
 
 ### Added
pytorch_lightning/trainer/data_loading.py
@@ -167,6 +167,12 @@ class TrainerDataLoadingMixin(ABC):
             model: The current `LightningModule`
         """
         self.train_dataloader = self.request_dataloader(model.train_dataloader)
 
+        if self.overfit_batches > 0:
+            if hasattr(self.train_dataloader, 'sampler') and isinstance(self.train_dataloader.sampler, RandomSampler):
+                rank_zero_warn('You requested to overfit but enabled training dataloader shuffling.'
+                               ' We are turning it off for you.')
+                self.train_dataloader = self.replace_sampler(
+                    self.train_dataloader, SequentialSampler(self.train_dataloader.dataset))
         # debugging
         self.dev_debugger.track_load_dataloader_call('train_dataloader', dataloaders=[self.train_dataloader])
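`replace_sampler` is a Lightning-internal helper. As a rough standalone sketch of what the new branch does (simplified; the real helper carries over more of the loader's configuration), one can rebuild the loader around a `SequentialSampler`:

    from torch.utils.data import DataLoader, SequentialSampler

    def sequential_copy(loader: DataLoader) -> DataLoader:
        # Simplified stand-in for Lightning's internal replace_sampler:
        # rebuild the loader over the same dataset with a sequential sampler.
        # Note: shuffle and sampler are mutually exclusive DataLoader args,
        # so the sampler must be swapped rather than toggled in place.
        return DataLoader(
            loader.dataset,
            batch_size=loader.batch_size,
            sampler=SequentialSampler(loader.dataset),
            num_workers=loader.num_workers,
            drop_last=loader.drop_last,
        )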
@@ -247,7 +253,7 @@ class TrainerDataLoadingMixin(ABC):
 
         # when overfitting, the dataloader should not have sampler
         if self.overfit_batches > 0:
-            rank_zero_warn('You requested to overfit but enabled training dataloader shuffling.'
+            rank_zero_warn('You requested to overfit but enabled test/val dataloader shuffling.'
                            ' We are turning it off for you.')
             dataloaders[loader_i] = self.replace_sampler(loader, SequentialSampler(loader.dataset))
 
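For reference, `overfit_batches` accepts either a fraction of the training set or an absolute number of batches; a short usage sketch of the public Trainer API:

    from pytorch_lightning import Trainer

    # Overfit on 11% of the training batches...
    trainer = Trainer(overfit_batches=0.11)

    # ...or on exactly 10 batches, now drawn in a fixed, unshuffled order.
    trainer = Trainer(overfit_batches=10)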
tests/trainer/test_dataloaders.py
@@ -56,11 +56,13 @@ def test_overfit_batch_limits(tmpdir):
     # ------------------------------------------------------
     # get the training loader and batch
     # ------------------------------------------------------
+    # Create a reference train dataloader without shuffling.
+    train_loader = DataLoader(model.train_dataloader().dataset, shuffle=False)
+    (xa, ya) = next(iter(train_loader))
     train_loader = DataLoader(model.train_dataloader().dataset, shuffle=True)
     full_train_samples = len(train_loader)
     num_train_samples = int(0.11 * full_train_samples)
 
-    (xa, ya) = next(iter(train_loader))
 
     # ------------------------------------------------------
     # set VAL and Test loaders
@@ -87,7 +89,8 @@ def test_overfit_batch_limits(tmpdir):
 
     trainer = Trainer(overfit_batches=0.11)
     trainer.reset_train_dataloader(model)
-    assert trainer.train_dataloader is train_loader
+    # The dataloader should have been overwritten with a Sequential sampler.
+    assert trainer.train_dataloader is not train_loader
     assert trainer.num_training_batches == num_train_samples
 
     # make sure the loaders are the same
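A stricter check (not part of this commit, an assumed follow-up) would assert the sampler type directly rather than relying on object identity alone:

    from torch.utils.data import SequentialSampler

    # After reset_train_dataloader with overfit_batches > 0, the training
    # loader is expected to carry a sequential sampler.
    assert isinstance(trainer.train_dataloader.sampler, SequentialSampler)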