Add datamodule parameter to lr_find() (#3425)

* Add datamodule parameter to lr_find()

* Fixed missing import

* Move datamodule parameter to end

* Add datamodule parameter test with auto_lr_find

* Change test for datamodule parameter

* Apply suggestions from code review

Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>

* Fix lr_find documentation

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

* formatting

* Add description to datamodule param in lr_find

* pep8: remove trailing whitespace on line 105

* added changelog

Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
Co-authored-by: Nicki Skafte <nugginea@gmail.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
GimmickNG 2020-10-01 02:33:12 -06:00 committed by GitHub
parent 7c61fc7c27
commit e4e60e9b82
4 changed files with 50 additions and 10 deletions


@@ -27,6 +27,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for datamodules to save and load checkpoints when training ([#3563](https://github.com/PyTorchLightning/pytorch-lightning/pull/3563))
- Added support for datamodule in learning rate finder ([#3425](https://github.com/PyTorchLightning/pytorch-lightning/pull/3425))
### Changed
- Changed `LearningRateLogger` to `LearningRateMonitor` ([#3251](https://github.com/PyTorchLightning/pytorch-lightning/pull/3251))


@@ -11,21 +11,23 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import os
from typing import List, Optional, Sequence, Union

import numpy as np
import torch
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader

from pytorch_lightning import _logger as log
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.core.datamodule import LightningDataModule
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.loggers.base import DummyLogger
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.parsing import lightning_hasattr, lightning_setattr
# check if ipywidgets is installed before importing tqdm.auto
# to ensure it won't fail and a progress bar is displayed
if importlib.util.find_spec('ipywidgets') is not None:
@@ -71,6 +73,7 @@ def lr_find(
num_training: int = 100,
mode: str = 'exponential',
early_stop_threshold: float = 4.0,
datamodule: Optional[LightningDataModule] = None,
):
r"""
lr_find enables the user to do a range test of good initial learning rates,
@@ -81,7 +84,7 @@
train_dataloader: A PyTorch
DataLoader with training samples. If the model has
a predefined train_dataloader method, this will be skipped.
min_lr: minimum learning rate to investigate
@@ -98,6 +101,12 @@
loss at any point is larger than early_stop_threshold*best_loss
then the search is stopped. To disable, set to None.
datamodule: An optional `LightningDataModule` which holds the training
and validation dataloader(s). Note that the `train_dataloader` and
`val_dataloaders` parameters cannot be used at the same time as
this parameter, or a `MisconfigurationException` will be raised.
Example::
# Setup model and trainer
@@ -167,7 +176,8 @@ def lr_find(
# Fit, lr & loss logged in callback
trainer.fit(model,
train_dataloader=train_dataloader,
val_dataloaders=val_dataloaders,
datamodule=datamodule)
# Prompt if we stopped early
if trainer.global_step != num_training:

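A minimal usage sketch of the new parameter (not part of this commit): the `ToyModel` and `ToyDataModule` classes below are illustrative placeholders. It mirrors the call pattern used in the new test, handing the datamodule to `trainer.tuner.lr_find`, which forwards it to `trainer.fit`; per the new docstring, combining it with `train_dataloader`/`val_dataloaders` is expected to raise a `MisconfigurationException`.

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    from pytorch_lightning import Trainer
    from pytorch_lightning.core.datamodule import LightningDataModule
    from pytorch_lightning.core.lightning import LightningModule


    class ToyDataModule(LightningDataModule):
        """Illustrative datamodule serving 64 random regression samples."""

        def setup(self, stage=None):
            self.dataset = TensorDataset(torch.randn(64, 32), torch.randn(64, 1))

        def train_dataloader(self):
            return DataLoader(self.dataset, batch_size=8)


    class ToyModel(LightningModule):
        """Illustrative model exposing `learning_rate` so the finder can update it."""

        def __init__(self, learning_rate=1e-3):
            super().__init__()
            self.learning_rate = learning_rate
            self.layer = torch.nn.Linear(32, 1)

        def training_step(self, batch, batch_idx):
            x, y = batch
            return torch.nn.functional.mse_loss(self.layer(x), y)

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=self.learning_rate)


    model = ToyModel()
    dm = ToyDataModule()
    trainer = Trainer(max_epochs=2)

    # The datamodule goes in via the new keyword; do not also pass raw dataloaders.
    lr_finder = trainer.tuner.lr_find(model, datamodule=dm)
    model.learning_rate = lr_finder.suggestion()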

@@ -15,6 +15,7 @@ from pytorch_lightning.tuner.batch_size_scaling import scale_batch_size
from pytorch_lightning.tuner.auto_gpu_select import pick_multiple_gpus
from pytorch_lightning.tuner.lr_finder import _run_lr_finder_internally, lr_find
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.datamodule import LightningDataModule
from typing import Optional, List, Union
from torch.utils.data import DataLoader
@@ -50,6 +51,7 @@ class Tuner:
num_training: int = 100,
mode: str = 'exponential',
early_stop_threshold: float = 4.0,
datamodule: Optional[LightningDataModule] = None
):
return lr_find(
self.trainer,
@@ -60,7 +62,8 @@
max_lr,
num_training,
mode,
early_stop_threshold,
datamodule,
)
def internal_find_lr(self, trainer, model: LightningModule):


@@ -5,6 +5,7 @@ import torch
from pytorch_lightning import Trainer
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import EvalModelTemplate
from tests.base.datamodules import TrialMNISTDataModule
def test_error_on_more_than_1_optimizer(tmpdir):
@@ -152,6 +153,30 @@ def test_call_to_trainer_method(tmpdir):
'Learning rate was not altered after running learning rate finder'
def test_datamodule_parameter(tmpdir):
    """ Test that the datamodule parameter works """

    # trial datamodule
    dm = TrialMNISTDataModule(tmpdir)

    hparams = EvalModelTemplate.get_default_hparams()
    model = EvalModelTemplate(**hparams)

    before_lr = hparams.get('learning_rate')
    # logger file to get meta
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=2,
    )

    lrfinder = trainer.tuner.lr_find(model, datamodule=dm)
    after_lr = lrfinder.suggestion()
    model.learning_rate = after_lr

    assert before_lr != after_lr, \
        'Learning rate was not altered after running learning rate finder'
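
A hedged sketch of an additional test (not part of the 4 changed files above) exercising the constraint spelled out in the new lr_find docstring, namely that a datamodule and explicit dataloaders cannot be mixed. It assumes `pytest` is imported in this test module and that `Tuner.lr_find` exposes the same `train_dataloader` keyword as the underlying `lr_find` function.

    def test_datamodule_and_dataloaders_are_mutually_exclusive(tmpdir):
        """ Sketch: a datamodule plus an explicit train_dataloader should be rejected """
        dm = TrialMNISTDataModule(tmpdir)
        model = EvalModelTemplate(**EvalModelTemplate.get_default_hparams())

        trainer = Trainer(default_root_dir=tmpdir, max_epochs=2)

        # trainer.fit() refuses mixed inputs, so the finder should surface the error
        with pytest.raises(MisconfigurationException):
            trainer.tuner.lr_find(
                model,
                train_dataloader=model.train_dataloader(),
                datamodule=dm,
            )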
def test_accumulation_and_early_stopping(tmpdir):
""" Test that early stopping of learning rate finder works, and that
accumulation also works for this feature """