522 lines
18 KiB
Python
Executable File
522 lines
18 KiB
Python
Executable File
"""
|
|
Trainer Learning Rate Finder
|
|
"""
|
|
import os
|
|
import importlib
|
|
from abc import ABC, abstractmethod
|
|
from typing import Optional, Sequence, Tuple, List, Union
|
|
|
|
import numpy as np
|
|
import torch
|
|
from torch.optim.lr_scheduler import _LRScheduler
|
|
from torch.utils.data import DataLoader
|
|
|
|
# check if ipywidgets is installed before importing tqdm.auto
|
|
# to ensure it won't fail and a progress bar is displayed
|
|
if importlib.util.find_spec('ipywidgets') is not None:
|
|
from tqdm.auto import tqdm
|
|
else:
|
|
from tqdm import tqdm
|
|
|
|
|
|
from pytorch_lightning.core.lightning import LightningModule
|
|
from pytorch_lightning.callbacks import Callback
|
|
from pytorch_lightning.loggers.base import DummyLogger
|
|
from pytorch_lightning import _logger as log
|
|
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
|
from pytorch_lightning.utilities import rank_zero_warn
|
|
|
|
|
|
class TrainerLRFinderMixin(ABC):
|
|
|
|
# this is just a summary on variables used in this abstract class,
|
|
# the proper values/initialisation should be done in child class
|
|
default_root_dir: str
|
|
progress_bar_callback: ...
|
|
global_step: int
|
|
total_batch_idx: int
|
|
on_gpu: bool
|
|
|
|
@abstractmethod
|
|
def save_checkpoint(self, *args):
|
|
"""Warning: this is just empty shell for code implemented in other class."""
|
|
|
|
@abstractmethod
|
|
def restore(self, *args):
|
|
"""Warning: this is just empty shell for code implemented in other class."""
|
|
|
|
@abstractmethod
|
|
def init_optimizers(self, *args) -> Tuple[List, List, List]:
|
|
"""Warning: this is just empty shell for code implemented in other class."""
|
|
|
|
@abstractmethod
|
|
def fit(self, *args):
|
|
"""Warning: this is just empty shell for code implemented in other class."""
|
|
|
|
def _run_lr_finder_internally(self, model: LightningModule):
|
|
""" Call lr finder internally during Trainer.fit() """
|
|
lr_finder = self.lr_find(model)
|
|
lr = lr_finder.suggestion()
|
|
# TODO: log lr.results to self.logger
|
|
if isinstance(self.auto_lr_find, str):
|
|
# Try to find requested field, may be nested
|
|
if _nested_hasattr(model, self.auto_lr_find):
|
|
_nested_setattr(model, self.auto_lr_find, lr)
|
|
else:
|
|
raise MisconfigurationException(
|
|
f'`auto_lr_find` was set to {self.auto_lr_find}, however'
|
|
' could not find this as a field in `model.hparams`.')
|
|
else:
|
|
if hasattr(model, 'lr'):
|
|
model.lr = lr
|
|
elif hasattr(model, 'learning_rate'):
|
|
model.learning_rate = lr
|
|
else:
|
|
raise MisconfigurationException(
|
|
'When auto_lr_find is set to True, expects that hparams'
|
|
' either has field `lr` or `learning_rate` that can overridden')
|
|
log.info(f'Learning rate set to {lr}')
|
|
|
|
def lr_find(self,
|
|
model: LightningModule,
|
|
train_dataloader: Optional[DataLoader] = None,
|
|
val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
|
|
min_lr: float = 1e-8,
|
|
max_lr: float = 1,
|
|
num_training: int = 100,
|
|
mode: str = 'exponential',
|
|
early_stop_threshold: float = 4.0,
|
|
num_accumulation_steps=None):
|
|
r"""
|
|
lr_find enables the user to do a range test of good initial learning rates,
|
|
to reduce the amount of guesswork in picking a good starting learning rate.
|
|
|
|
Args:
|
|
model: Model to do range testing for
|
|
|
|
train_dataloader: A PyTorch
|
|
DataLoader with training samples. If the model has
|
|
a predefined train_dataloader method this will be skipped.
|
|
|
|
min_lr: minimum learning rate to investigate
|
|
|
|
max_lr: maximum learning rate to investigate
|
|
|
|
num_training: number of learning rates to test
|
|
|
|
mode: search strategy, either 'linear' or 'exponential'. If set to
|
|
'linear' the learning rate will be searched by linearly increasing
|
|
after each batch. If set to 'exponential', will increase learning
|
|
rate exponentially.
|
|
|
|
early_stop_threshold: threshold for stopping the search. If the
|
|
loss at any point is larger than early_stop_threshold*best_loss
|
|
then the search is stopped. To disable, set to None.
|
|
|
|
num_accumulation_steps: deprepecated, number of batches to calculate loss over.
|
|
Set trainer argument ``accumulate_grad_batches`` instead.
|
|
|
|
Example::
|
|
|
|
# Setup model and trainer
|
|
model = MyModelClass(hparams)
|
|
trainer = pl.Trainer()
|
|
|
|
# Run lr finder
|
|
lr_finder = trainer.lr_find(model, ...)
|
|
|
|
# Inspect results
|
|
fig = lr_finder.plot(); fig.show()
|
|
suggested_lr = lr_finder.suggestion()
|
|
|
|
# Overwrite lr and create new model
|
|
hparams.lr = suggested_lr
|
|
model = MyModelClass(hparams)
|
|
|
|
# Ready to train with new learning rate
|
|
trainer.fit(model)
|
|
|
|
"""
|
|
if num_accumulation_steps is not None:
|
|
rank_zero_warn("Argument `num_accumulation_steps` has been deprepecated"
|
|
" since v0.7.6 and will be removed in 0.9. Please"
|
|
" set trainer argument `accumulate_grad_batches` instead.",
|
|
DeprecationWarning)
|
|
|
|
save_path = os.path.join(self.default_root_dir, 'lr_find_temp.ckpt')
|
|
|
|
self.__lr_finder_dump_params(model)
|
|
|
|
# Prevent going into infinite loop
|
|
self.auto_lr_find = False
|
|
|
|
# Initialize lr finder object (stores results)
|
|
lr_finder = _LRFinder(mode, min_lr, max_lr, num_training)
|
|
|
|
# Use special lr logger callback
|
|
self.callbacks = [_LRCallback(num_training,
|
|
early_stop_threshold,
|
|
progress_bar_refresh_rate=1)]
|
|
|
|
# No logging
|
|
self.logger = DummyLogger()
|
|
|
|
# Max step set to number of iterations
|
|
self.max_steps = num_training
|
|
|
|
# Disable standard progress bar for fit
|
|
if self.progress_bar_callback:
|
|
self.progress_bar_callback.disable()
|
|
|
|
# Disable standard checkpoint & early stopping
|
|
self.checkpoint_callback = False
|
|
self.early_stop_callback = None
|
|
|
|
# Required for saving the model
|
|
self.optimizers, self.schedulers = [], [],
|
|
self.model = model
|
|
|
|
# Dump model checkpoint
|
|
self.save_checkpoint(str(save_path))
|
|
|
|
# Configure optimizer and scheduler
|
|
optimizers, _, _ = self.init_optimizers(model)
|
|
|
|
if len(optimizers) != 1:
|
|
raise MisconfigurationException(
|
|
f'`model.configure_optimizers()` returned {len(optimizers)}, but'
|
|
' learning rate finder only works with single optimizer')
|
|
model.configure_optimizers = lr_finder._get_new_optimizer(optimizers[0])
|
|
|
|
# Fit, lr & loss logged in callback
|
|
self.fit(model,
|
|
train_dataloader=train_dataloader,
|
|
val_dataloaders=val_dataloaders)
|
|
|
|
# Prompt if we stopped early
|
|
if self.global_step != num_training:
|
|
log.info('LR finder stopped early due to diverging loss.')
|
|
|
|
# Transfer results from callback to lr finder object
|
|
lr_finder.results.update({'lr': self.callbacks[0].lrs,
|
|
'loss': self.callbacks[0].losses})
|
|
lr_finder._total_batch_idx = self.total_batch_idx # for debug purpose
|
|
|
|
# Reset model state
|
|
self.restore(str(save_path), on_gpu=self.on_gpu)
|
|
os.remove(save_path)
|
|
|
|
# Finish by resetting variables so trainer is ready to fit model
|
|
self.__lr_finder_restore_params(model)
|
|
if self.progress_bar_callback:
|
|
self.progress_bar_callback.enable()
|
|
|
|
return lr_finder
|
|
|
|
def __lr_finder_dump_params(self, model):
|
|
# Prevent going into infinite loop
|
|
self.__dumped_params = {
|
|
'auto_lr_find': self.auto_lr_find,
|
|
'callbacks': self.callbacks,
|
|
'logger': self.logger,
|
|
'max_steps': self.max_steps,
|
|
'checkpoint_callback': self.checkpoint_callback,
|
|
'early_stop_callback': self.early_stop_callback,
|
|
'configure_optimizers': model.configure_optimizers,
|
|
}
|
|
|
|
def __lr_finder_restore_params(self, model):
|
|
self.auto_lr_find = self.__dumped_params['auto_lr_find']
|
|
self.logger = self.__dumped_params['logger']
|
|
self.callbacks = self.__dumped_params['callbacks']
|
|
self.max_steps = self.__dumped_params['max_steps']
|
|
self.checkpoint_callback = self.__dumped_params['checkpoint_callback']
|
|
self.early_stop_callback = self.__dumped_params['early_stop_callback']
|
|
model.configure_optimizers = self.__dumped_params['configure_optimizers']
|
|
del self.__dumped_params
|
|
|
|
|
|
class _LRFinder(object):
|
|
""" LR finder object. This object stores the results of Trainer.lr_find().
|
|
|
|
Args:
|
|
mode: either `linear` or `exponential`, how to increase lr after each step
|
|
|
|
lr_min: lr to start search from
|
|
|
|
lr_max: lr to stop search
|
|
|
|
num_training: number of steps to take between lr_min and lr_max
|
|
|
|
Example::
|
|
# Run lr finder
|
|
lr_finder = trainer.lr_find(model)
|
|
|
|
# Results stored in
|
|
lr_finder.results
|
|
|
|
# Plot using
|
|
lr_finder.plot()
|
|
|
|
# Get suggestion
|
|
lr = lr_finder.suggestion()
|
|
"""
|
|
def __init__(self, mode: str, lr_min: float, lr_max: float, num_training: int):
|
|
assert mode in ('linear', 'exponential'), \
|
|
'mode should be either `linear` or `exponential`'
|
|
|
|
self.mode = mode
|
|
self.lr_min = lr_min
|
|
self.lr_max = lr_max
|
|
self.num_training = num_training
|
|
|
|
self.results = {}
|
|
self._total_batch_idx = 0 # for debug purpose
|
|
|
|
def _get_new_optimizer(self, optimizer: torch.optim.Optimizer):
|
|
""" Construct a new `configure_optimizers()` method, that has a optimizer
|
|
with initial lr set to lr_min and a scheduler that will either
|
|
linearly or exponentially increase the lr to lr_max in num_training steps.
|
|
|
|
Args:
|
|
optimizer: instance of `torch.optim.Optimizer`
|
|
|
|
"""
|
|
new_lrs = [self.lr_min] * len(optimizer.param_groups)
|
|
for param_group, new_lr in zip(optimizer.param_groups, new_lrs):
|
|
param_group["lr"] = new_lr
|
|
param_group["initial_lr"] = new_lr
|
|
|
|
args = (optimizer, self.lr_max, self.num_training)
|
|
scheduler = _LinearLR(*args) if self.mode == 'linear' else _ExponentialLR(*args)
|
|
|
|
def configure_optimizers():
|
|
return [optimizer], [{'scheduler': scheduler,
|
|
'interval': 'step'}]
|
|
|
|
return configure_optimizers
|
|
|
|
def plot(self, suggest: bool = False, show: bool = False):
|
|
""" Plot results from lr_find run
|
|
Args:
|
|
suggest: if True, will mark suggested lr to use with a red point
|
|
|
|
show: if True, will show figure
|
|
"""
|
|
import matplotlib.pyplot as plt
|
|
|
|
lrs = self.results["lr"]
|
|
losses = self.results["loss"]
|
|
|
|
fig, ax = plt.subplots()
|
|
|
|
# Plot loss as a function of the learning rate
|
|
ax.plot(lrs, losses)
|
|
if self.mode == 'exponential':
|
|
ax.set_xscale("log")
|
|
ax.set_xlabel("Learning rate")
|
|
ax.set_ylabel("Loss")
|
|
|
|
if suggest:
|
|
_ = self.suggestion()
|
|
if self._optimal_idx:
|
|
ax.plot(lrs[self._optimal_idx], losses[self._optimal_idx],
|
|
markersize=10, marker='o', color='red')
|
|
|
|
if show:
|
|
plt.show()
|
|
|
|
return fig
|
|
|
|
def suggestion(self, skip_begin: int = 10, skip_end: int = 1):
|
|
""" This will propose a suggestion for choice of initial learning rate
|
|
as the point with the steepest negative gradient.
|
|
|
|
Returns:
|
|
lr: suggested initial learning rate to use
|
|
skip_begin: how many samples to skip in the beginning. Prevent too naive estimates
|
|
skip_end: how many samples to skip in the end. Prevent too optimistic estimates
|
|
|
|
"""
|
|
try:
|
|
loss = np.array(self.results["loss"][skip_begin:-skip_end])
|
|
loss = loss[np.isfinite(loss)]
|
|
min_grad = np.gradient(loss).argmin()
|
|
self._optimal_idx = min_grad + skip_begin
|
|
return self.results["lr"][self._optimal_idx]
|
|
except Exception:
|
|
log.exception('Failed to compute suggesting for `lr`. There might not be enough points.')
|
|
self._optimal_idx = None
|
|
|
|
|
|
class _LRCallback(Callback):
|
|
""" Special callback used by the learning rate finder. This callbacks log
|
|
the learning rate before each batch and log the corresponding loss after
|
|
each batch.
|
|
|
|
Args:
|
|
num_training: number of iterations done by the learning rate finder
|
|
early_stop_threshold: threshold for stopping the search. If the
|
|
loss at any point is larger than ``early_stop_threshold*best_loss``
|
|
then the search is stopped. To disable, set to ``None``.
|
|
progress_bar_refresh_rate: rate to refresh the progress bar for
|
|
the learning rate finder
|
|
beta: smoothing value, the loss being logged is a running average of
|
|
loss values logged until now. ``beta`` controls the forget rate i.e.
|
|
if ``beta=0`` all past information is ignored.
|
|
|
|
"""
|
|
def __init__(self, num_training: int,
|
|
early_stop_threshold: float = 4.0,
|
|
progress_bar_refresh_rate: int = 0,
|
|
beta: float = 0.98):
|
|
self.num_training = num_training
|
|
self.early_stop_threshold = early_stop_threshold
|
|
self.beta = beta
|
|
self.losses = []
|
|
self.lrs = []
|
|
self.avg_loss = 0.0
|
|
self.best_loss = 0.0
|
|
self.progress_bar_refresh_rate = progress_bar_refresh_rate
|
|
self.progress_bar = None
|
|
|
|
def on_batch_start(self, trainer, pl_module):
|
|
""" Called before each training batch, logs the lr that will be used """
|
|
if (trainer.batch_idx + 1) % trainer.accumulate_grad_batches != 0:
|
|
return
|
|
|
|
if self.progress_bar_refresh_rate and self.progress_bar is None:
|
|
self.progress_bar = tqdm(desc='Finding best initial lr', total=self.num_training)
|
|
|
|
self.lrs.append(trainer.lr_schedulers[0]['scheduler'].lr[0])
|
|
|
|
def on_batch_end(self, trainer, pl_module):
|
|
""" Called when the training batch ends, logs the calculated loss """
|
|
if (trainer.batch_idx + 1) % trainer.accumulate_grad_batches != 0:
|
|
return
|
|
|
|
if self.progress_bar:
|
|
self.progress_bar.update()
|
|
|
|
current_loss = trainer.running_loss.last().item()
|
|
current_step = trainer.global_step + 1 # remove the +1 in 1.0
|
|
|
|
# Avg loss (loss with momentum) + smoothing
|
|
self.avg_loss = self.beta * self.avg_loss + (1 - self.beta) * current_loss
|
|
smoothed_loss = self.avg_loss / (1 - self.beta**current_step)
|
|
|
|
# Check if we diverging
|
|
if self.early_stop_threshold is not None:
|
|
if current_step > 1 and smoothed_loss > self.early_stop_threshold * self.best_loss:
|
|
trainer.max_steps = current_step # stop signal
|
|
if self.progress_bar:
|
|
self.progress_bar.close()
|
|
|
|
# Save best loss for diverging checking
|
|
if smoothed_loss < self.best_loss or current_step == 1:
|
|
self.best_loss = smoothed_loss
|
|
|
|
self.losses.append(smoothed_loss)
|
|
|
|
|
|
class _LinearLR(_LRScheduler):
|
|
"""Linearly increases the learning rate between two boundaries
|
|
over a number of iterations.
|
|
Arguments:
|
|
|
|
optimizer: wrapped optimizer.
|
|
|
|
end_lr: the final learning rate.
|
|
|
|
num_iter: the number of iterations over which the test occurs.
|
|
|
|
last_epoch: the index of last epoch. Default: -1.
|
|
"""
|
|
last_epoch: int
|
|
base_lrs: Sequence
|
|
|
|
def __init__(self,
|
|
optimizer: torch.optim.Optimizer,
|
|
end_lr: float,
|
|
num_iter: int,
|
|
last_epoch: int = -1):
|
|
self.end_lr = end_lr
|
|
self.num_iter = num_iter
|
|
super(_LinearLR, self).__init__(optimizer, last_epoch)
|
|
|
|
def get_lr(self):
|
|
curr_iter = self.last_epoch + 1
|
|
r = curr_iter / self.num_iter
|
|
|
|
if self.last_epoch > 0:
|
|
val = [base_lr + r * (self.end_lr - base_lr) for base_lr in self.base_lrs]
|
|
else:
|
|
val = [base_lr for base_lr in self.base_lrs]
|
|
self._lr = val
|
|
return val
|
|
|
|
@property
|
|
def lr(self):
|
|
return self._lr
|
|
|
|
|
|
class _ExponentialLR(_LRScheduler):
|
|
"""Exponentially increases the learning rate between two boundaries
|
|
over a number of iterations.
|
|
|
|
Arguments:
|
|
|
|
optimizer: wrapped optimizer.
|
|
|
|
end_lr: the final learning rate.
|
|
|
|
num_iter: the number of iterations over which the test occurs.
|
|
|
|
last_epoch: the index of last epoch. Default: -1.
|
|
"""
|
|
last_epoch: int
|
|
base_lrs: Sequence
|
|
|
|
def __init__(self,
|
|
optimizer: torch.optim.Optimizer,
|
|
end_lr: float,
|
|
num_iter: int,
|
|
last_epoch: int = -1):
|
|
self.end_lr = end_lr
|
|
self.num_iter = num_iter
|
|
super(_ExponentialLR, self).__init__(optimizer, last_epoch)
|
|
|
|
def get_lr(self):
|
|
curr_iter = self.last_epoch + 1
|
|
r = curr_iter / self.num_iter
|
|
|
|
if self.last_epoch > 0:
|
|
val = [base_lr * (self.end_lr / base_lr) ** r for base_lr in self.base_lrs]
|
|
else:
|
|
val = [base_lr for base_lr in self.base_lrs]
|
|
self._lr = val
|
|
return val
|
|
|
|
@property
|
|
def lr(self):
|
|
return self._lr
|
|
|
|
|
|
def _nested_hasattr(obj, path):
|
|
parts = path.split(".")
|
|
for part in parts:
|
|
if hasattr(obj, part):
|
|
obj = getattr(obj, part)
|
|
else:
|
|
return False
|
|
else:
|
|
return True
|
|
|
|
|
|
def _nested_setattr(obj, path, val):
|
|
parts = path.split(".")
|
|
for part in parts[:-1]:
|
|
if hasattr(obj, part):
|
|
obj = getattr(obj, part)
|
|
setattr(obj, parts[-1], val)
|