Fix dataloaders are not reset when tuning the model (#7566)

Co-authored-by: Carlos Mocholi <carlossmocholi@gmail.com>
Xinyao(Alvin) Sun 2021-05-24 02:21:45 -06:00 committed by GitHub
parent 299f2c481b
commit 0c958c5a1f
3 changed files with 38 additions and 21 deletions


@@ -120,6 +120,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ### Fixed
+- Fixed dataloaders are not reset when tuning the model ([#7566](https://github.com/PyTorchLightning/pytorch-lightning/pull/7566))
 - Fixed parsing of multiple training dataloaders ([#7433](https://github.com/PyTorchLightning/pytorch-lightning/pull/7433))


@@ -160,7 +160,10 @@ def _run_power_scaling(
             else:
                 raise  # some other error not memory related
-        if not changed:
+        if changed:
+            # Force the train dataloader to reset as the batch size has changed
+            trainer.reset_train_dataloader(model)
+        else:
             break
     return new_size
@@ -192,7 +195,10 @@ def _run_binsearch_scaling(
             else:
                 new_size, changed = _adjust_batch_size(trainer, batch_arg_name, factor=2.0, desc='succeeded')
-            if not changed:
+            if changed:
+                # Force the train dataloader to reset as the batch size has changed
+                trainer.reset_train_dataloader(model)
+            else:
                 break
         except RuntimeError as exception:
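For context, the control flow these two hunks introduce, written out as plain code. This is a sketch only, assuming the `trainer`, `model`, `batch_arg_name`, and `_adjust_batch_size` names used in this module; it is not the full scaling routine:

for _ in range(max_trials):
    trainer.fit(model)
    # Grow the batch size attribute on the model/datamodule.
    new_size, changed = _adjust_batch_size(trainer, batch_arg_name, factor=2.0, desc='succeeded')
    if changed:
        # The existing train dataloader was built with the old batch size,
        # so force it to be rebuilt before the next trial.
        trainer.reset_train_dataloader(model)
    else:
        # The batch size could not be changed any further; stop scaling.
        break

Without the reset, later trials keep running with the dataloader created for the initial batch size, which is the behaviour the changelog entry above describes as fixed.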


@@ -24,14 +24,14 @@ from pytorch_lightning.tuner.tuning import Tuner
 from pytorch_lightning.utilities import AMPType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.base import EvalModelTemplate
-from tests.helpers import BoringDataModule, BoringModel
+from tests.helpers import BoringDataModule, BoringModel, RandomDataset
 from tests.helpers.datamodules import MNISTDataModule
 from tests.helpers.runif import RunIf


 class BatchSizeDataModule(BoringDataModule):

-    def __init__(self, batch_size=None):
+    def __init__(self, batch_size):
         super().__init__()
         if batch_size is not None:
             self.batch_size = batch_size
@@ -42,21 +42,23 @@ class BatchSizeDataModule(BoringDataModule):
 class BatchSizeModel(BoringModel):

-    def __init__(self, batch_size=None):
+    def __init__(self, batch_size):
         super().__init__()
         if batch_size is not None:
             self.batch_size = batch_size
+
+    def train_dataloader(self):
+        return DataLoader(RandomDataset(32, 64), batch_size=getattr(self, "batch_size", 1))
-@pytest.mark.parametrize(
-    "model,datamodule", [
-        (BatchSizeModel(2), None),
-        (BatchSizeModel(2), BatchSizeDataModule(2)),
-        (BatchSizeModel(2), BatchSizeDataModule(None)),
-        (BatchSizeModel(None), BatchSizeDataModule(2)),
-    ]
-)
-def test_scale_batch_size_method_with_model_or_datamodule(tmpdir, model, datamodule):
+@pytest.mark.parametrize(["model_bs", "dm_bs"], [
+    (2, -1),
+    (2, 2),
+    (2, None),
+    (None, 2),
+    (16, 16),
+])
+def test_scale_batch_size_method_with_model_or_datamodule(tmpdir, model_bs, dm_bs):
     """ Test the tuner method `Tuner.scale_batch_size` with a datamodule. """
     trainer = Trainer(
         default_root_dir=tmpdir,
@@ -65,14 +67,21 @@ def test_scale_batch_size_method_with_model_or_datamodule(tmpdir, model, datamodule):
         max_epochs=1,
     )
     tuner = Tuner(trainer)
-    new_batch_size = tuner.scale_batch_size(
-        model=model, mode="binsearch", init_val=4, max_trials=2, datamodule=datamodule
-    )
+
+    model = BatchSizeModel(model_bs)
+    datamodule = BatchSizeDataModule(dm_bs) if dm_bs != -1 else None
+
+    new_batch_size = tuner.scale_batch_size(model, mode="binsearch", init_val=4, max_trials=2, datamodule=datamodule)
     assert new_batch_size == 16
-    if hasattr(model, "batch_size"):
-        assert model.batch_size == 16
-    if datamodule is not None and hasattr(datamodule, "batch_size"):
-        assert datamodule.batch_size == 16
+
+    if model_bs is not None:
+        assert model.batch_size == new_batch_size
+        if dm_bs == -1:
+            # datamodule batch size takes precedence
+            assert trainer.train_dataloader.loaders.batch_size == new_batch_size
+    if dm_bs not in (-1, None):
+        assert datamodule.batch_size == new_batch_size
+        assert trainer.train_dataloader.loaders.batch_size == new_batch_size
 def test_model_reset_correctly(tmpdir):
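For reference, a minimal usage sketch of the API the updated test exercises, assuming the BatchSizeModel and BatchSizeDataModule helpers defined above; the root directory and batch sizes are illustrative:

from pytorch_lightning import Trainer
from pytorch_lightning.tuner.tuning import Tuner

# Mirrors the call pattern in the test: the tuner scales the batch size on
# whichever object defines it (model and/or datamodule) and, with this fix,
# resets the train dataloader so it reflects the final value.
trainer = Trainer(default_root_dir="/tmp/tuner_example", max_epochs=1)
tuner = Tuner(trainer)
model = BatchSizeModel(batch_size=2)
datamodule = BatchSizeDataModule(batch_size=2)  # pass datamodule=None to tune the model attribute instead
new_batch_size = tuner.scale_batch_size(model, mode="binsearch", init_val=4, max_trials=2, datamodule=datamodule)
print(new_batch_size)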