Fix tuner.scale_batch_size not finding batch size attribute when using datamodule (#5968)

Adrian Wälchli 2021-03-14 09:16:19 +01:00 committed by GitHub
parent 680e83adab
commit b2bcad1132
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 80 additions and 3 deletions


@@ -178,6 +178,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed error thrown when using valid distributed mode in multi node ([#6297](https://github.com/PyTorchLightning/pytorch-lightning/pull/6297))
+- Fixed an issue with `Tuner.scale_batch_size` not finding the batch size attribute in the datamodule ([#5968](https://github.com/PyTorchLightning/pytorch-lightning/pull/5968))

 ## [1.2.1] - 2021-02-23

 ### Fixed
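
The entry above describes the user-facing behaviour: the batch size finder can now scale a `batch_size` attribute that lives on the datamodule rather than on the model. Below is a minimal sketch of that scenario, assuming a made-up `MyDataModule` and dataset; only the `Tuner.scale_batch_size` call itself mirrors the new test added by this commit.

```python
import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningDataModule, Trainer
from pytorch_lightning.tuner.tuning import Tuner


class RandomDataset(Dataset):
    """Tiny synthetic dataset, only here to keep the sketch self-contained."""

    def __len__(self):
        return 64

    def __getitem__(self, idx):
        return torch.randn(32)


class MyDataModule(LightningDataModule):
    """Hypothetical datamodule that owns the batch size instead of the model."""

    def __init__(self, batch_size=2):
        super().__init__()
        self.batch_size = batch_size  # the attribute the tuner scales in place

    def train_dataloader(self):
        return DataLoader(RandomDataset(), batch_size=self.batch_size)


trainer = Trainer(max_epochs=1, limit_train_batches=1, limit_val_batches=0)
tuner = Tuner(trainer)
# `model` is any LightningModule *without* its own batch_size attribute;
# with this fix the lookup falls through to MyDataModule.batch_size:
# new_size = tuner.scale_batch_size(model, mode="binsearch", datamodule=MyDataModule())
```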


@@ -105,7 +105,7 @@ class TrainLoop:
         # provide rank to profiler
         self.trainer.profile_connector.on_train_start(self.trainer)

-    def setup_fit(self, model, train_dataloader, val_dataloaders, datamodule):
+    def setup_fit(self, model, train_dataloader=None, val_dataloaders=None, datamodule=None):
         # clean hparams
         if hasattr(model, "hparams"):
             parsing.clean_namespace(model.hparams)


@@ -33,13 +33,20 @@ class Tuner:
         self.trainer.auto_lr_find = auto_lr_find
         self.trainer.auto_scale_batch_size = auto_scale_batch_size

-    def tune(self, model, train_dataloader, val_dataloaders, datamodule):
+    def setup_trainer(
+        self,
+        model: LightningModule,
+        train_dataloader: Optional[DataLoader] = None,
+        val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
+        datamodule: LightningDataModule = None,
+    ):
         self.trainer.model_connector.copy_trainer_model_properties(model)
         # setup data, etc...
         self.trainer.train_loop.setup_fit(model, train_dataloader, val_dataloaders, datamodule)
         # hook
         self.trainer.data_connector.prepare_data(model)

+    def tune(self, model, train_dataloader, val_dataloaders, datamodule):
         # Run auto batch size scaling
         if self.trainer.auto_scale_batch_size:
             if isinstance(self.trainer.auto_scale_batch_size, bool):
@@ -104,6 +111,7 @@ class Tuner:
                 or datamodule.
         """
+        self.setup_trainer(model, **fit_kwargs)
         return scale_batch_size(
             self.trainer,
             model,
@@ -128,6 +136,7 @@ class Tuner:
         datamodule: Optional[LightningDataModule] = None,
         update_attr: bool = False,
     ):
+        self.setup_trainer(model, train_dataloader, val_dataloaders, datamodule)
         return lr_find(
             self.trainer,
             model,
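
Taken together, the hunks above make both tuner entry points call `setup_trainer` first, so the datamodule passed by the user is wired up through `setup_fit` before the batch size attribute is looked up. The sketch below illustrates the resulting lookup order with a hypothetical helper name; the library's actual helper and its exact precedence rules are not part of this diff.

```python
def find_batch_size_owner(model, datamodule, attribute="batch_size"):
    """Hypothetical helper: return the object whose ``attribute`` the tuner should scale."""
    # 1. prefer an attribute defined directly on the LightningModule
    if hasattr(model, attribute):
        return model
    # 2. otherwise fall back to the datamodule, which the tuner can only see
    #    once setup_trainer/setup_fit has attached it to the trainer
    if datamodule is not None and hasattr(datamodule, attribute):
        return datamodule
    raise AttributeError(f"neither the model nor the datamodule defines {attribute!r}")
```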


@@ -0,0 +1,65 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
from torch.utils.data import DataLoader

from pytorch_lightning import Trainer
from pytorch_lightning.tuner.tuning import Tuner
from tests.helpers import BoringDataModule, BoringModel


class BatchSizeDataModule(BoringDataModule):

    def __init__(self, batch_size=None):
        super().__init__()
        if batch_size is not None:
            self.batch_size = batch_size

    def train_dataloader(self):
        return DataLoader(self.random_train, batch_size=getattr(self, "batch_size", 1))


class BatchSizeModel(BoringModel):

    def __init__(self, batch_size=None):
        super().__init__()
        if batch_size is not None:
            self.batch_size = batch_size


@pytest.mark.parametrize(
    "model,datamodule", [
        (BatchSizeModel(2), None),
        (BatchSizeModel(2), BatchSizeDataModule(2)),
        (BatchSizeModel(2), BatchSizeDataModule(None)),
        (BatchSizeModel(None), BatchSizeDataModule(2)),
    ]
)
def test_scale_batch_size_method_with_model_or_datamodule(tmpdir, model, datamodule):
    """ Test the tuner method `Tuner.scale_batch_size` with a datamodule. """
    trainer = Trainer(
        default_root_dir=tmpdir,
        limit_train_batches=1,
        limit_val_batches=0,
        max_epochs=1,
    )
    tuner = Tuner(trainer)
    new_batch_size = tuner.scale_batch_size(
        model=model, mode="binsearch", init_val=4, max_trials=2, datamodule=datamodule
    )
    # two successful trials starting from init_val=4 double the size twice: 4 -> 8 -> 16
    assert new_batch_size == 16

    if hasattr(model, "batch_size"):
        assert model.batch_size == 16
    if datamodule is not None and hasattr(datamodule, "batch_size"):
        assert datamodule.batch_size == 16
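
The test drives the tuner directly; the same code path also matters for the automatic entry point. Below is a sketch, assuming the `auto_scale_batch_size` Trainer flag and `Trainer.tune` behave as in this release line; the model and datamodule reuse the classes from the test above.

```python
from pytorch_lightning import Trainer

model = BatchSizeModel(batch_size=None)         # the model does not own the attribute
datamodule = BatchSizeDataModule(batch_size=2)  # the datamodule does

# a string value for auto_scale_batch_size selects the search mode,
# matching the isinstance(..., bool) branch handled in Tuner.tune above
trainer = Trainer(max_epochs=1, limit_train_batches=1, auto_scale_batch_size="binsearch")
trainer.tune(model, datamodule=datamodule)

print(datamodule.batch_size)  # updated in place on the datamodule
```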