Add datamodule parameter to lr_find() (#3425)
* Add datamodule parameter to lr_find() * Fixed missing import * Move datamodule parameter to end * Add datamodule parameter test with auto_lr_find * Change test for datamodule parameter * Apply suggestions from code review Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Fix lr_find documentation Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com> * formatting * Add description to datamodule param in lr_find * pep8: remove trailing whitespace on line 105 * added changelog Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> Co-authored-by: Nicki Skafte <nugginea@gmail.com> Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
This commit is contained in:
parent
7c61fc7c27
commit
e4e60e9b82
|
@ -27,6 +27,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
|
||||
- Added support for datamodules to save and load checkpoints when training ([#3563]https://github.com/PyTorchLightning/pytorch-lightning/pull/3563)
|
||||
|
||||
- Added support for datamodule in learning rate finder ([#3425](https://github.com/PyTorchLightning/pytorch-lightning/pull/3425))
|
||||
|
||||
### Changed
|
||||
|
||||
- Changed `LearningRateLogger` to `LearningRateMonitor` ([#3251](https://github.com/PyTorchLightning/pytorch-lightning/pull/3251))
|
||||
|
|
|
@ -11,21 +11,23 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import importlib
|
||||
import os
|
||||
from typing import List, Optional, Sequence, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from typing import Optional, Sequence, List, Union
|
||||
from torch.optim.lr_scheduler import _LRScheduler
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from pytorch_lightning import _logger as log
|
||||
from pytorch_lightning.callbacks import Callback
|
||||
from pytorch_lightning.core.datamodule import LightningDataModule
|
||||
from pytorch_lightning.core.lightning import LightningModule
|
||||
from pytorch_lightning.loggers.base import DummyLogger
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from torch.optim.lr_scheduler import _LRScheduler
|
||||
import importlib
|
||||
from pytorch_lightning import _logger as log
|
||||
import numpy as np
|
||||
from pytorch_lightning.callbacks import Callback
|
||||
from pytorch_lightning.utilities.parsing import lightning_hasattr, lightning_setattr
|
||||
|
||||
|
||||
# check if ipywidgets is installed before importing tqdm.auto
|
||||
# to ensure it won't fail and a progress bar is displayed
|
||||
if importlib.util.find_spec('ipywidgets') is not None:
|
||||
|
@ -71,6 +73,7 @@ def lr_find(
|
|||
num_training: int = 100,
|
||||
mode: str = 'exponential',
|
||||
early_stop_threshold: float = 4.0,
|
||||
datamodule: Optional[LightningDataModule] = None,
|
||||
):
|
||||
r"""
|
||||
lr_find enables the user to do a range test of good initial learning rates,
|
||||
|
@ -81,7 +84,7 @@ def lr_find(
|
|||
|
||||
train_dataloader: A PyTorch
|
||||
DataLoader with training samples. If the model has
|
||||
a predefined train_dataloader method this will be skipped.
|
||||
a predefined train_dataloader method, this will be skipped.
|
||||
|
||||
min_lr: minimum learning rate to investigate
|
||||
|
||||
|
@ -98,6 +101,12 @@ def lr_find(
|
|||
loss at any point is larger than early_stop_threshold*best_loss
|
||||
then the search is stopped. To disable, set to None.
|
||||
|
||||
datamodule: An optional `LightningDataModule` which holds the training
|
||||
and validation dataloader(s). Note that the `train_dataloader` and
|
||||
`val_dataloaders` parameters cannot be used at the same time as
|
||||
this parameter, or a `MisconfigurationException` will be raised.
|
||||
|
||||
|
||||
Example::
|
||||
|
||||
# Setup model and trainer
|
||||
|
@ -167,7 +176,8 @@ def lr_find(
|
|||
# Fit, lr & loss logged in callback
|
||||
trainer.fit(model,
|
||||
train_dataloader=train_dataloader,
|
||||
val_dataloaders=val_dataloaders)
|
||||
val_dataloaders=val_dataloaders,
|
||||
datamodule=datamodule)
|
||||
|
||||
# Prompt if we stopped early
|
||||
if trainer.global_step != num_training:
|
||||
|
|
|
@ -15,6 +15,7 @@ from pytorch_lightning.tuner.batch_size_scaling import scale_batch_size
|
|||
from pytorch_lightning.tuner.auto_gpu_select import pick_multiple_gpus
|
||||
from pytorch_lightning.tuner.lr_finder import _run_lr_finder_internally, lr_find
|
||||
from pytorch_lightning.core.lightning import LightningModule
|
||||
from pytorch_lightning.core.datamodule import LightningDataModule
|
||||
from typing import Optional, List, Union
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
|
@ -50,6 +51,7 @@ class Tuner:
|
|||
num_training: int = 100,
|
||||
mode: str = 'exponential',
|
||||
early_stop_threshold: float = 4.0,
|
||||
datamodule: Optional[LightningDataModule] = None
|
||||
):
|
||||
return lr_find(
|
||||
self.trainer,
|
||||
|
@ -60,7 +62,8 @@ class Tuner:
|
|||
max_lr,
|
||||
num_training,
|
||||
mode,
|
||||
early_stop_threshold
|
||||
early_stop_threshold,
|
||||
datamodule,
|
||||
)
|
||||
|
||||
def internal_find_lr(self, trainer, model: LightningModule):
|
||||
|
|
|
@ -5,6 +5,7 @@ import torch
|
|||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from tests.base import EvalModelTemplate
|
||||
from tests.base.datamodules import TrialMNISTDataModule
|
||||
|
||||
|
||||
def test_error_on_more_than_1_optimizer(tmpdir):
|
||||
|
@ -152,6 +153,30 @@ def test_call_to_trainer_method(tmpdir):
|
|||
'Learning rate was not altered after running learning rate finder'
|
||||
|
||||
|
||||
def test_datamodule_parameter(tmpdir):
|
||||
""" Test that the datamodule parameter works """
|
||||
|
||||
# trial datamodule
|
||||
dm = TrialMNISTDataModule(tmpdir)
|
||||
|
||||
hparams = EvalModelTemplate.get_default_hparams()
|
||||
model = EvalModelTemplate(**hparams)
|
||||
|
||||
before_lr = hparams.get('learning_rate')
|
||||
# logger file to get meta
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
max_epochs=2,
|
||||
)
|
||||
|
||||
lrfinder = trainer.tuner.lr_find(model, datamodule=dm)
|
||||
after_lr = lrfinder.suggestion()
|
||||
model.learning_rate = after_lr
|
||||
|
||||
assert before_lr != after_lr, \
|
||||
'Learning rate was not altered after running learning rate finder'
|
||||
|
||||
|
||||
def test_accumulation_and_early_stopping(tmpdir):
|
||||
""" Test that early stopping of learning rate finder works, and that
|
||||
accumulation also works for this feature """
|
||||
|
|
Loading…
Reference in New Issue