Fix dataloaders are not reset when tuning the model (#7566)
Co-authored-by: Carlos Mocholi <carlossmocholi@gmail.com>
parent 299f2c481b
commit 0c958c5a1f
@@ -120,6 +120,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ### Fixed

+- Fixed dataloaders are not reset when tuning the model ([#7566](https://github.com/PyTorchLightning/pytorch-lightning/pull/7566))
+
 - Fixed parsing of multiple training dataloaders ([#7433](https://github.com/PyTorchLightning/pytorch-lightning/pull/7433))

@@ -160,7 +160,10 @@ def _run_power_scaling(
             else:
                 raise  # some other error not memory related

-        if not changed:
+        if changed:
+            # Force the train dataloader to reset as the batch size has changed
+            trainer.reset_train_dataloader(model)
+        else:
             break
     return new_size

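Why the forced reset is needed: a torch DataLoader captures its batch size at construction, so mutating `model.batch_size` after a successful trial does not affect an already-built loader. A minimal plain-PyTorch sketch of that behaviour (illustration only, not Lightning code):

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    dataset = TensorDataset(torch.randn(64, 32))
    loader = DataLoader(dataset, batch_size=2)

    # Changing an attribute elsewhere does not touch the existing loader;
    # it still yields batches of the size captured at construction.
    assert loader.batch_size == 2
    assert next(iter(loader))[0].shape[0] == 2

    # Picking up a new batch size requires rebuilding the loader, which is
    # what trainer.reset_train_dataloader(model) forces after each resize.
    loader = DataLoader(dataset, batch_size=8)
    assert next(iter(loader))[0].shape[0] == 8
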
@@ -192,7 +195,10 @@ def _run_binsearch_scaling(
             else:
                 new_size, changed = _adjust_batch_size(trainer, batch_arg_name, factor=2.0, desc='succeeded')

-            if not changed:
+            if changed:
+                # Force the train dataloader to reset as the batch size has changed
+                trainer.reset_train_dataloader(model)
+            else:
                 break

         except RuntimeError as exception:

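For context, both helpers run inside the batch-size finder, which is typically reached as below; a sketch assuming this version's Trainer API, with `MyLightningModule` as a hypothetical stand-in for any module exposing a `batch_size` attribute:

    from pytorch_lightning import Trainer

    model = MyLightningModule()  # hypothetical module with a `batch_size` attribute
    trainer = Trainer(auto_scale_batch_size="binsearch", max_epochs=1)

    # trainer.tune() invokes the scaler, which after this fix rebuilds the
    # train dataloader every time the candidate batch size changes.
    trainer.tune(model)
    print(model.batch_size)  # tuned value, written back onto the model
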
@@ -24,14 +24,14 @@ from pytorch_lightning.tuner.tuning import Tuner
 from pytorch_lightning.utilities import AMPType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.base import EvalModelTemplate
-from tests.helpers import BoringDataModule, BoringModel
+from tests.helpers import BoringDataModule, BoringModel, RandomDataset
 from tests.helpers.datamodules import MNISTDataModule
 from tests.helpers.runif import RunIf


 class BatchSizeDataModule(BoringDataModule):

-    def __init__(self, batch_size=None):
+    def __init__(self, batch_size):
         super().__init__()
         if batch_size is not None:
             self.batch_size = batch_size

@@ -42,21 +42,23 @@ class BatchSizeDataModule(BoringDataModule):

 class BatchSizeModel(BoringModel):

-    def __init__(self, batch_size=None):
+    def __init__(self, batch_size):
         super().__init__()
         if batch_size is not None:
             self.batch_size = batch_size

+    def train_dataloader(self):
+        return DataLoader(RandomDataset(32, 64), batch_size=getattr(self, "batch_size", 1))
+

-@pytest.mark.parametrize(
-    "model,datamodule", [
-        (BatchSizeModel(2), None),
-        (BatchSizeModel(2), BatchSizeDataModule(2)),
-        (BatchSizeModel(2), BatchSizeDataModule(None)),
-        (BatchSizeModel(None), BatchSizeDataModule(2)),
-    ]
-)
-def test_scale_batch_size_method_with_model_or_datamodule(tmpdir, model, datamodule):
+@pytest.mark.parametrize(["model_bs", "dm_bs"], [
+    (2, -1),
+    (2, 2),
+    (2, None),
+    (None, 2),
+    (16, 16),
+])
+def test_scale_batch_size_method_with_model_or_datamodule(tmpdir, model_bs, dm_bs):
     """ Test the tuner method `Tuner.scale_batch_size` with a datamodule. """
     trainer = Trainer(
         default_root_dir=tmpdir,

@@ -65,14 +67,21 @@ def test_scale_batch_size_method_with_model_or_datamodule(tmpdir, model, datamodule):
         max_epochs=1,
     )
     tuner = Tuner(trainer)
-    new_batch_size = tuner.scale_batch_size(
-        model=model, mode="binsearch", init_val=4, max_trials=2, datamodule=datamodule
-    )
+
+    model = BatchSizeModel(model_bs)
+    datamodule = BatchSizeDataModule(dm_bs) if dm_bs != -1 else None
+
+    new_batch_size = tuner.scale_batch_size(model, mode="binsearch", init_val=4, max_trials=2, datamodule=datamodule)
     assert new_batch_size == 16
-    if hasattr(model, "batch_size"):
-        assert model.batch_size == 16
-    if datamodule is not None and hasattr(datamodule, "batch_size"):
-        assert datamodule.batch_size == 16
+
+    if model_bs is not None:
+        assert model.batch_size == new_batch_size
+        if dm_bs == -1:
+            # datamodule batch size takes precedence
+            assert trainer.train_dataloader.loaders.batch_size == new_batch_size
+    if dm_bs not in (-1, None):
+        assert datamodule.batch_size == new_batch_size
+        assert trainer.train_dataloader.loaders.batch_size == new_batch_size


 def test_model_reset_correctly(tmpdir):

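A note on the `# datamodule batch size takes precedence` assertion: when a datamodule is attached, the trainer rebuilds its train dataloader from the datamodule rather than from the model, so the loader-level check against the model's value only holds in the no-datamodule case (`dm_bs == -1`). A sketch of that selection rule, as a hypothetical helper mirroring the behaviour the test relies on, not the trainer's actual resolution code:

    def train_dataloader_source(model, datamodule):
        # An attached datamodule's train_dataloader wins over the model's.
        source = datamodule if datamodule is not None else model
        return source.train_dataloader()

    # e.g. with the fixtures above:
    #   train_dataloader_source(BatchSizeModel(2), None).batch_size == 2
    #   train_dataloader_source(BatchSizeModel(2), BatchSizeDataModule(4)).batch_size == 4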