diff --git a/CHANGELOG.md b/CHANGELOG.md index 9dcdea4c16..b94f39e66c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -178,6 +178,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed error thrown when using valid distributed mode in multi node ([#6297](https://github.com/PyTorchLightning/pytorch-lightning/pull/6297) +- Fixed an issue with `Tuner.scale_batch_size` not finding the batch size attribute in the datamodule ([#5968](https://github.com/PyTorchLightning/pytorch-lightning/pull/5968)) + + ## [1.2.1] - 2021-02-23 ### Fixed diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 88b87afcb9..7e737c424f 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -105,7 +105,7 @@ class TrainLoop: # provide rank to profiler self.trainer.profile_connector.on_train_start(self.trainer) - def setup_fit(self, model, train_dataloader, val_dataloaders, datamodule): + def setup_fit(self, model, train_dataloader=None, val_dataloaders=None, datamodule=None): # clean hparams if hasattr(model, "hparams"): parsing.clean_namespace(model.hparams) diff --git a/pytorch_lightning/tuner/tuning.py b/pytorch_lightning/tuner/tuning.py index 78810141b1..b9fa9afe0e 100644 --- a/pytorch_lightning/tuner/tuning.py +++ b/pytorch_lightning/tuner/tuning.py @@ -33,13 +33,20 @@ class Tuner: self.trainer.auto_lr_find = auto_lr_find self.trainer.auto_scale_batch_size = auto_scale_batch_size - def tune(self, model, train_dataloader, val_dataloaders, datamodule): + def setup_trainer( + self, + model: LightningModule, + train_dataloader: Optional[DataLoader] = None, + val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, + datamodule: LightningDataModule = None, + ): + self.trainer.model_connector.copy_trainer_model_properties(model) # setup data, etc... self.trainer.train_loop.setup_fit(model, train_dataloader, val_dataloaders, datamodule) - # hook self.trainer.data_connector.prepare_data(model) + def tune(self, model, train_dataloader, val_dataloaders, datamodule): # Run auto batch size scaling if self.trainer.auto_scale_batch_size: if isinstance(self.trainer.auto_scale_batch_size, bool): @@ -104,6 +111,7 @@ class Tuner: or datamodule. """ + self.setup_trainer(model, **fit_kwargs) return scale_batch_size( self.trainer, model, @@ -128,6 +136,7 @@ class Tuner: datamodule: Optional[LightningDataModule] = None, update_attr: bool = False, ): + self.setup_trainer(model, train_dataloader, val_dataloaders, datamodule) return lr_find( self.trainer, model, diff --git a/tests/tuner/test_scale_batch_size.py b/tests/tuner/test_scale_batch_size.py new file mode 100644 index 0000000000..ad7fc57092 --- /dev/null +++ b/tests/tuner/test_scale_batch_size.py @@ -0,0 +1,65 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest +from torch.utils.data import DataLoader + +from pytorch_lightning import Trainer +from pytorch_lightning.tuner.tuning import Tuner +from tests.helpers import BoringDataModule, BoringModel + + +class BatchSizeDataModule(BoringDataModule): + + def __init__(self, batch_size=None): + super().__init__() + if batch_size is not None: + self.batch_size = batch_size + + def train_dataloader(self): + return DataLoader(self.random_train, batch_size=getattr(self, "batch_size", 1)) + + +class BatchSizeModel(BoringModel): + + def __init__(self, batch_size=None): + super().__init__() + if batch_size is not None: + self.batch_size = batch_size + + +@pytest.mark.parametrize( + "model,datamodule", [ + (BatchSizeModel(2), None), + (BatchSizeModel(2), BatchSizeDataModule(2)), + (BatchSizeModel(2), BatchSizeDataModule(None)), + (BatchSizeModel(None), BatchSizeDataModule(2)), + ] +) +def test_scale_batch_size_method_with_model_or_datamodule(tmpdir, model, datamodule): + """ Test the tuner method `Tuner.scale_batch_size` with a datamodule. """ + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=1, + limit_val_batches=0, + max_epochs=1, + ) + tuner = Tuner(trainer) + new_batch_size = tuner.scale_batch_size( + model=model, mode="binsearch", init_val=4, max_trials=2, datamodule=datamodule + ) + assert new_batch_size == 16 + if hasattr(model, "batch_size"): + assert model.batch_size == 16 + if datamodule is not None and hasattr(datamodule, "batch_size"): + assert datamodule.batch_size == 16