Fix tuner.scale_batch_size not finding batch size attribute when using datamodule (#5968)
parent 680e83adab, commit b2bcad1132

@@ -178,6 +178,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed error thrown when using valid distributed mode in multi node ([#6297](https://github.com/PyTorchLightning/pytorch-lightning/pull/6297))

+- Fixed an issue with `Tuner.scale_batch_size` not finding the batch size attribute in the datamodule ([#5968](https://github.com/PyTorchLightning/pytorch-lightning/pull/5968))

## [1.2.1] - 2021-02-23

### Fixed

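For context, the scenario behind the #5968 entry is a `LightningDataModule` that owns the `batch_size` attribute while the model does not. Below is a minimal sketch of that setup, assuming the 1.2-era `auto_scale_batch_size` / `trainer.tune` entry point; `TinyModel` and `TinyDataModule` are illustrative names, not part of the change.

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningDataModule, LightningModule, Trainer


class TinyModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


class TinyDataModule(LightningDataModule):
    def __init__(self, batch_size=2):
        super().__init__()
        self.batch_size = batch_size  # the attribute the tuner has to locate

    def train_dataloader(self):
        data = TensorDataset(torch.randn(64, 32), torch.randn(64, 2))
        return DataLoader(data, batch_size=self.batch_size)


# Before this commit the tuner failed to find `batch_size` on the datamodule;
# with the fix, tuning scales TinyDataModule.batch_size in place.
trainer = Trainer(auto_scale_batch_size="binsearch", max_epochs=1)
trainer.tune(TinyModel(), datamodule=TinyDataModule())
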
@@ -105,7 +105,7 @@ class TrainLoop:
        # provide rank to profiler
        self.trainer.profile_connector.on_train_start(self.trainer)

-    def setup_fit(self, model, train_dataloader, val_dataloaders, datamodule):
+    def setup_fit(self, model, train_dataloader=None, val_dataloaders=None, datamodule=None):
        # clean hparams
        if hasattr(model, "hparams"):
            parsing.clean_namespace(model.hparams)

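The only change here is giving `setup_fit` keyword defaults, which lets a caller that has only a model and a datamodule skip the dataloader arguments. A toy stand-in (not the real `TrainLoop` method) to show the calling convention the tuner relies on:

def setup_fit(model, train_dataloader=None, val_dataloaders=None, datamodule=None):
    """Toy stand-in for TrainLoop.setup_fit, just to illustrate the signature."""
    print(model, train_dataloader, val_dataloaders, datamodule)


setup_fit("model", "train_loader", "val_loader", "dm")  # old style: everything passed explicitly
setup_fit("model", datamodule="dm")                     # now valid: pass only what you have
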
@@ -33,13 +33,20 @@ class Tuner:
        self.trainer.auto_lr_find = auto_lr_find
        self.trainer.auto_scale_batch_size = auto_scale_batch_size

-    def tune(self, model, train_dataloader, val_dataloaders, datamodule):
+    def setup_trainer(
+        self,
+        model: LightningModule,
+        train_dataloader: Optional[DataLoader] = None,
+        val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
+        datamodule: LightningDataModule = None,
+    ):
        self.trainer.model_connector.copy_trainer_model_properties(model)
        # setup data, etc...
        self.trainer.train_loop.setup_fit(model, train_dataloader, val_dataloaders, datamodule)

        # hook
        self.trainer.data_connector.prepare_data(model)

+    def tune(self, model, train_dataloader, val_dataloaders, datamodule):
        # Run auto batch size scaling
        if self.trainer.auto_scale_batch_size:
            if isinstance(self.trainer.auto_scale_batch_size, bool):

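`setup_trainer` matters for the fix because it runs `train_loop.setup_fit`, which attaches the datamodule before any batch-size lookup happens. A rough sketch of the lookup order the tuner needs is below; it is an illustration written for this note, not a copy of Lightning's internal `lightning_hasattr` helper.

def find_batch_size_owner(model, datamodule, attr="batch_size"):
    """Return the first object exposing `attr`: the model, its hparams, or the datamodule."""
    for owner in (model, getattr(model, "hparams", None), datamodule):
        if owner is not None and hasattr(owner, attr):
            return owner
    raise AttributeError(f"no `{attr}` found on the model, its hparams, or the datamodule")

Without the new `setup_trainer` call the datamodule was never attached before scaling started, which appears to be the failure mode behind #5968.
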
@@ -104,6 +111,7 @@ class Tuner:
            or datamodule.

        """
+        self.setup_trainer(model, **fit_kwargs)
        return scale_batch_size(
            self.trainer,
            model,

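With that one added line, anything passed through `**fit_kwargs` (a datamodule and/or plain dataloaders, assuming the 1.2-era keyword names) is attached before scaling starts. A sketch of the direct `Tuner` calls this enables; `trainer`, `model`, `dm`, and `loader` are placeholders:

tuner = Tuner(trainer)
tuner.scale_batch_size(model, mode="power", init_val=2, datamodule=dm)  # batch_size lives on the datamodule
tuner.scale_batch_size(model, train_dataloader=loader)                  # or pass a plain DataLoader
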
@@ -128,6 +136,7 @@ class Tuner:
        datamodule: Optional[LightningDataModule] = None,
        update_attr: bool = False,
    ):
+        self.setup_trainer(model, train_dataloader, val_dataloaders, datamodule)
        return lr_find(
            self.trainer,
            model,

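The same setup step is added to `lr_find`, so a learning-rate sweep can also read its data from a datamodule. A minimal sketch, assuming the 1.2-era `Tuner.lr_find` signature; `model` and `dm` are placeholders:

tuner = Tuner(trainer)
lr_finder = tuner.lr_find(model, datamodule=dm)  # dataloaders come from the datamodule
print(lr_finder.suggestion())                    # suggested learning rate
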
@@ -0,0 +1,65 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
from torch.utils.data import DataLoader

from pytorch_lightning import Trainer
from pytorch_lightning.tuner.tuning import Tuner
from tests.helpers import BoringDataModule, BoringModel


class BatchSizeDataModule(BoringDataModule):

    def __init__(self, batch_size=None):
        super().__init__()
        if batch_size is not None:
            self.batch_size = batch_size

    def train_dataloader(self):
        return DataLoader(self.random_train, batch_size=getattr(self, "batch_size", 1))


class BatchSizeModel(BoringModel):

    def __init__(self, batch_size=None):
        super().__init__()
        if batch_size is not None:
            self.batch_size = batch_size


@pytest.mark.parametrize(
    "model,datamodule", [
        (BatchSizeModel(2), None),
        (BatchSizeModel(2), BatchSizeDataModule(2)),
        (BatchSizeModel(2), BatchSizeDataModule(None)),
        (BatchSizeModel(None), BatchSizeDataModule(2)),
    ]
)
def test_scale_batch_size_method_with_model_or_datamodule(tmpdir, model, datamodule):
    """ Test the tuner method `Tuner.scale_batch_size` with a datamodule. """
    trainer = Trainer(
        default_root_dir=tmpdir,
        limit_train_batches=1,
        limit_val_batches=0,
        max_epochs=1,
    )
    tuner = Tuner(trainer)
    new_batch_size = tuner.scale_batch_size(
        model=model, mode="binsearch", init_val=4, max_trials=2, datamodule=datamodule
    )
    assert new_batch_size == 16
    if hasattr(model, "batch_size"):
        assert model.batch_size == 16
    if datamodule is not None and hasattr(datamodule, "batch_size"):
        assert datamodule.batch_size == 16

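Why the parametrized cases all expect 16: binary-search scaling starts at `init_val=4` and, since every trial here trivially fits (one training batch, no validation), the size is doubled once per successful trial until `max_trials=2` is exhausted. A back-of-the-envelope check, not Lightning code:

size = 4             # init_val
for _ in range(2):   # max_trials successful trials
    size *= 2        # each successful trial roughly doubles the batch size
assert size == 16    # matches the `new_batch_size` the test expects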