[feat] Add Trainer(stochastic_weight_avg=True/False) (#6038)
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: Kaushik B <45285388+kaushikb11@users.noreply.github.com> Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
This commit is contained in:
parent
8d7ac8f0f8
commit
c9622bafe0
|
@ -111,6 +111,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Added `PL_TORCH_DISTRIBUTED_BACKEND` env variable to select backend ([#5981](https://github.com/PyTorchLightning/pytorch-lightning/pull/5981))
|
||||
|
||||
|
||||
- Added `Trainer` flag to activate Stochastic Weight Averaging (SWA) `Trainer(stochastic_weight_avg=True)` ([#6038](https://github.com/PyTorchLightning/pytorch-lightning/pull/6038))
|
||||
|
||||
|
||||
### Changed
|
||||
|
||||
- Changed `stat_scores` metric now calculates stat scores over all classes and gains new parameters, in line with the new `StatScores` metric ([#4839](https://github.com/PyTorchLightning/pytorch-lightning/pull/4839))
|
||||
|
|
|
@ -37,7 +37,6 @@ __all__ = [
|
|||
'ModelPruning',
|
||||
'ProgressBar',
|
||||
'ProgressBarBase',
|
||||
'ModelPruning',
|
||||
'QuantizationAwareTraining',
|
||||
'StochasticWeightAveraging',
|
||||
]
|
||||
|
|
|
@ -23,6 +23,7 @@ from torch import nn
|
|||
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning.callbacks.base import Callback
|
||||
from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config
|
||||
from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_6, rank_zero_warn
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
|
||||
|
@ -96,8 +97,10 @@ class StochasticWeightAveraging(Callback):
|
|||
raise MisconfigurationException(err_msg)
|
||||
|
||||
if (
|
||||
swa_lrs is not None and (
|
||||
not isinstance(swa_lrs, (float, list)) or isinstance(swa_lrs, float) and swa_lrs <= 0
|
||||
or isinstance(swa_lrs, list) and not all(lr > 0 and isinstance(lr, float) for lr in swa_lrs)
|
||||
)
|
||||
):
|
||||
raise MisconfigurationException("The `swa_lrs` should be a positive float or a list of positive float.")
|
||||
|
||||
|
@ -131,11 +134,13 @@ class StochasticWeightAveraging(Callback):
|
|||
def on_before_accelerator_backend_setup(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule'):
|
||||
# copy the model before moving it to accelerator device.
|
||||
self._average_model = deepcopy(pl_module)
|
||||
|
||||
def on_fit_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule'):
|
||||
optimizers = trainer.optimizers
|
||||
lr_schedulers = trainer.lr_schedulers
|
||||
|
||||
if len(optimizers) > 1:
|
||||
raise MisconfigurationException("SWA currently not supported for more than 1 `optimizer`.")
|
||||
if len(optimizers) != 1:
|
||||
raise MisconfigurationException("SWA currently works with 1 `optimizer`.")
|
||||
|
||||
if len(lr_schedulers) > 1:
|
||||
raise MisconfigurationException("SWA currently not supported for more than 1 `lr_scheduler`.")
|
||||
|
@ -156,18 +161,37 @@ class StochasticWeightAveraging(Callback):
|
|||
self._average_model = self._average_model.to(self._device or pl_module.device)
|
||||
|
||||
optimizers = trainer.optimizers
|
||||
lr_scheduler = trainer.lr_schedulers[0]["scheduler"]
|
||||
|
||||
for param_group in optimizers[0].param_groups:
|
||||
if self._swa_lrs is None:
|
||||
initial_lr = param_group["lr"]
|
||||
|
||||
elif isinstance(self._swa_lrs, float):
|
||||
initial_lr = self._swa_lrs
|
||||
|
||||
else:
|
||||
initial_lr = self._swa_lrs[0]
|
||||
|
||||
param_group["initial_lr"] = initial_lr
|
||||
|
||||
self._swa_lrs = initial_lr
|
||||
|
||||
self._swa_scheduler = SWALR(
|
||||
optimizers[0],
|
||||
swa_lr=self._swa_lrs,
|
||||
swa_lr=initial_lr,
|
||||
anneal_epochs=self._annealing_epochs,
|
||||
anneal_strategy=self._annealing_strategy,
|
||||
last_epoch=trainer.max_epochs if self._annealing_strategy == "cos" else -1
|
||||
)
|
||||
|
||||
if trainer.lr_schedulers:
|
||||
lr_scheduler = trainer.lr_schedulers[0]["scheduler"]
|
||||
rank_zero_warn(f"Swapping lr_scheduler {lr_scheduler} for {self._swa_scheduler}")
|
||||
trainer.lr_schedulers[0]["scheduler"] = self._swa_scheduler
|
||||
else:
|
||||
_scheduler_config = _get_default_scheduler_config()
|
||||
_scheduler_config["scheduler"] = self._swa_scheduler
|
||||
trainer.lr_schedulers.append(_scheduler_config)
|
||||
|
||||
self.n_averaged = torch.tensor(0, dtype=torch.long, device=pl_module.device)
|
||||
|
||||
|
|
|
@ -14,7 +14,12 @@
|
|||
import os
|
||||
from typing import List, Union
|
||||
|
||||
from pytorch_lightning.callbacks import Callback, ModelCheckpoint, ProgressBar, ProgressBarBase
|
||||
from pytorch_lightning.callbacks import (
|
||||
Callback,
|
||||
ModelCheckpoint,
|
||||
ProgressBar,
|
||||
ProgressBarBase,
|
||||
)
|
||||
from pytorch_lightning.core.lightning import LightningModule
|
||||
from pytorch_lightning.utilities import rank_zero_info, rank_zero_warn
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
|
@ -34,12 +39,14 @@ class CallbackConnector:
|
|||
default_root_dir,
|
||||
weights_save_path,
|
||||
resume_from_checkpoint,
|
||||
stochastic_weight_avg,
|
||||
):
|
||||
self.trainer.resume_from_checkpoint = resume_from_checkpoint
|
||||
|
||||
# init folder paths for checkpoint + weights save callbacks
|
||||
self.trainer._default_root_dir = default_root_dir or os.getcwd()
|
||||
self.trainer._weights_save_path = weights_save_path or self.trainer._default_root_dir
|
||||
self.trainer._stochastic_weight_avg = stochastic_weight_avg
|
||||
|
||||
# init callbacks
|
||||
if isinstance(callbacks, Callback):
|
||||
|
@ -50,6 +57,9 @@ class CallbackConnector:
|
|||
# pass through the required args to figure out defaults
|
||||
self.configure_checkpoint_callbacks(checkpoint_callback)
|
||||
|
||||
# configure swa callback
|
||||
self._configure_swa_callbacks()
|
||||
|
||||
# init progress bar
|
||||
self.trainer._progress_bar_callback = self.configure_progress_bar(progress_bar_refresh_rate, process_position)
|
||||
|
||||
|
@ -76,6 +86,15 @@ class CallbackConnector:
|
|||
if not self._trainer_has_checkpoint_callbacks() and checkpoint_callback is True:
|
||||
self.trainer.callbacks.append(ModelCheckpoint(dirpath=None, filename=None, mode='min'))
|
||||
|
||||
def _configure_swa_callbacks(self):
|
||||
if not self.trainer._stochastic_weight_avg:
|
||||
return
|
||||
|
||||
from pytorch_lightning.callbacks.swa import StochasticWeightAveraging
|
||||
existing_swa = [cb for cb in self.trainer.callbacks if isinstance(cb, StochasticWeightAveraging)]
|
||||
if not existing_swa:
|
||||
self.trainer.callbacks = [StochasticWeightAveraging()] + self.trainer.callbacks
|
||||
|
||||
def configure_progress_bar(self, refresh_rate=None, process_position=0):
|
||||
if os.getenv('COLAB_GPU') and refresh_rate is None:
|
||||
# smaller refresh rate on colab causes crashes, choose a higher value
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
# limitations under the License.
|
||||
|
||||
from abc import ABC
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import List, Optional, Tuple, Dict, Any
|
||||
|
||||
import torch
|
||||
from torch import optim
|
||||
|
@ -98,15 +98,7 @@ class TrainerOptimizersMixin(ABC):
|
|||
def configure_schedulers(self, schedulers: list, monitor: Optional[str] = None):
|
||||
# Convert each scheduler into dict structure with relevant information
|
||||
lr_schedulers = []
|
||||
default_config = {
|
||||
'scheduler': None,
|
||||
'name': None, # no custom name
|
||||
'interval': 'epoch', # after epoch is over
|
||||
'frequency': 1, # every epoch/batch
|
||||
'reduce_on_plateau': False, # most often not ReduceLROnPlateau scheduler
|
||||
'monitor': monitor, # value to monitor for ReduceLROnPlateau
|
||||
'strict': True, # enforce that the monitor exists for ReduceLROnPlateau
|
||||
}
|
||||
default_config = _get_default_scheduler_config()
|
||||
for scheduler in schedulers:
|
||||
if isinstance(scheduler, dict):
|
||||
# check provided keys
|
||||
|
@ -185,3 +177,15 @@ def _validate_scheduler_optimizer(optimizers, lr_schedulers):
|
|||
raise MisconfigurationException(
|
||||
"Some schedulers are attatched with an optimizer that wasn't returned from `configure_optimizers`."
|
||||
)
|
||||
|
||||
|
||||
def _get_default_scheduler_config() -> Dict[str, Any]:
|
||||
return {
|
||||
'scheduler': None,
|
||||
'name': None, # no custom name
|
||||
'interval': 'epoch', # after epoch is over
|
||||
'frequency': 1, # every epoch/batch
|
||||
'reduce_on_plateau': False, # most often not ReduceLROnPlateau scheduler
|
||||
'monitor': None, # value to monitor for ReduceLROnPlateau
|
||||
'strict': True, # enforce that the monitor exists for ReduceLROnPlateau
|
||||
}
|
||||
|
|
|
@ -139,6 +139,7 @@ class Trainer(
|
|||
move_metrics_to_cpu: bool = False,
|
||||
enable_pl_optimizer: bool = None, # todo: remove in v1.3
|
||||
multiple_trainloader_mode: str = 'max_size_cycle',
|
||||
stochastic_weight_avg: bool = False
|
||||
):
|
||||
r"""
|
||||
Customize every aspect of training via flags
|
||||
|
@ -297,6 +298,10 @@ class Trainer(
|
|||
In 'max_size_cycle' mode, the trainer ends one epoch when the largest dataset is traversed,
|
||||
and smaller datasets reload when running out of their data. In 'min_size' mode, all the datasets
|
||||
reload when reaching the minimum length of datasets.
|
||||
|
||||
stochastic_weight_avg: Whether to use `Stochastic Weight Averaging (SWA)
|
||||
<https://pytorch.org/blog/pytorch-1.6-now-includes-stochastic-weight-averaging/>_`
|
||||
|
||||
"""
|
||||
super().__init__()
|
||||
self._running_stage = None
|
||||
|
@ -333,13 +338,8 @@ class Trainer(
|
|||
# init callbacks
|
||||
# Declare attributes to be set in callback_connector on_trainer_init
|
||||
self.callback_connector.on_trainer_init(
|
||||
callbacks,
|
||||
checkpoint_callback,
|
||||
progress_bar_refresh_rate,
|
||||
process_position,
|
||||
default_root_dir,
|
||||
weights_save_path,
|
||||
resume_from_checkpoint,
|
||||
callbacks, checkpoint_callback, progress_bar_refresh_rate, process_position, default_root_dir,
|
||||
weights_save_path, resume_from_checkpoint, stochastic_weight_avg
|
||||
)
|
||||
|
||||
# hook
|
||||
|
|
|
@ -157,3 +157,32 @@ def test_swa_raises():
|
|||
StochasticWeightAveraging(swa_epoch_start=-1, swa_lrs=0.1)
|
||||
with pytest.raises(MisconfigurationException, match="positive float or a list of positive float"):
|
||||
StochasticWeightAveraging(swa_epoch_start=5, swa_lrs=[0.2, 1])
|
||||
|
||||
|
||||
@pytest.mark.parametrize('stochastic_weight_avg', [False, True])
|
||||
@pytest.mark.parametrize('use_callbacks', [False, True])
|
||||
@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_6, reason="SWA available from PyTorch 1.6.0")
|
||||
def test_trainer_and_stochastic_weight_avg(tmpdir, use_callbacks, stochastic_weight_avg):
|
||||
"""Test to ensure SWA Callback is injected when `stochastic_weight_avg` is provided to the Trainer"""
|
||||
|
||||
class TestModel(BoringModel):
|
||||
|
||||
def configure_optimizers(self):
|
||||
optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
|
||||
return optimizer
|
||||
|
||||
model = TestModel()
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
callbacks=StochasticWeightAveraging(swa_lrs=1e-3) if use_callbacks else None,
|
||||
stochastic_weight_avg=stochastic_weight_avg,
|
||||
limit_train_batches=4,
|
||||
limit_val_batches=4,
|
||||
max_epochs=2,
|
||||
)
|
||||
trainer.fit(model)
|
||||
if use_callbacks or stochastic_weight_avg:
|
||||
assert len([cb for cb in trainer.callbacks if isinstance(cb, StochasticWeightAveraging)]) == 1
|
||||
assert trainer.callbacks[0]._swa_lrs == (1e-3 if use_callbacks else 0.1)
|
||||
else:
|
||||
assert all(not isinstance(cb, StochasticWeightAveraging) for cb in trainer.callbacks)
|
||||
|
|
Loading…
Reference in New Issue