[feat] Add Trainer(stochastic_weight_avg=True/False) (#6038)

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Kaushik B <45285388+kaushikb11@users.noreply.github.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
This commit is contained in:
chaton 2021-02-17 23:10:05 +00:00 committed by GitHub
parent 8d7ac8f0f8
commit c9622bafe0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 105 additions and 27 deletions

View File

@ -111,6 +111,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `PL_TORCH_DISTRIBUTED_BACKEND` env variable to select backend ([#5981](https://github.com/PyTorchLightning/pytorch-lightning/pull/5981))
- Added `Trainer` flag to activate Stochastic Weight Averaging (SWA) `Trainer(stochastic_weight_avg=True)` ([#6038](https://github.com/PyTorchLightning/pytorch-lightning/pull/6038))
### Changed
- Changed `stat_scores` metric now calculates stat scores over all classes and gains new parameters, in line with the new `StatScores` metric ([#4839](https://github.com/PyTorchLightning/pytorch-lightning/pull/4839))

View File

@ -37,7 +37,6 @@ __all__ = [
'ModelPruning',
'ProgressBar',
'ProgressBarBase',
'ModelPruning',
'QuantizationAwareTraining',
'StochasticWeightAveraging',
]

View File

@ -23,6 +23,7 @@ from torch import nn
import pytorch_lightning as pl
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config
from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_6, rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@ -96,8 +97,10 @@ class StochasticWeightAveraging(Callback):
raise MisconfigurationException(err_msg)
if (
not isinstance(swa_lrs, (float, list)) or isinstance(swa_lrs, float) and swa_lrs <= 0
or isinstance(swa_lrs, list) and not all(lr > 0 and isinstance(lr, float) for lr in swa_lrs)
swa_lrs is not None and (
not isinstance(swa_lrs, (float, list)) or isinstance(swa_lrs, float) and swa_lrs <= 0
or isinstance(swa_lrs, list) and not all(lr > 0 and isinstance(lr, float) for lr in swa_lrs)
)
):
raise MisconfigurationException("The `swa_lrs` should be a positive float or a list of positive float.")
@ -131,11 +134,13 @@ class StochasticWeightAveraging(Callback):
def on_before_accelerator_backend_setup(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule'):
# copy the model before moving it to accelerator device.
self._average_model = deepcopy(pl_module)
def on_fit_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule'):
optimizers = trainer.optimizers
lr_schedulers = trainer.lr_schedulers
if len(optimizers) > 1:
raise MisconfigurationException("SWA currently not supported for more than 1 `optimizer`.")
if len(optimizers) != 1:
raise MisconfigurationException("SWA currently works with 1 `optimizer`.")
if len(lr_schedulers) > 1:
raise MisconfigurationException("SWA currently not supported for more than 1 `lr_scheduler`.")
@ -156,18 +161,37 @@ class StochasticWeightAveraging(Callback):
self._average_model = self._average_model.to(self._device or pl_module.device)
optimizers = trainer.optimizers
lr_scheduler = trainer.lr_schedulers[0]["scheduler"]
for param_group in optimizers[0].param_groups:
if self._swa_lrs is None:
initial_lr = param_group["lr"]
elif isinstance(self._swa_lrs, float):
initial_lr = self._swa_lrs
else:
initial_lr = self._swa_lrs[0]
param_group["initial_lr"] = initial_lr
self._swa_lrs = initial_lr
self._swa_scheduler = SWALR(
optimizers[0],
swa_lr=self._swa_lrs,
swa_lr=initial_lr,
anneal_epochs=self._annealing_epochs,
anneal_strategy=self._annealing_strategy,
last_epoch=trainer.max_epochs if self._annealing_strategy == "cos" else -1
)
rank_zero_warn(f"Swapping lr_scheduler {lr_scheduler} for {self._swa_scheduler}")
trainer.lr_schedulers[0]["scheduler"] = self._swa_scheduler
if trainer.lr_schedulers:
lr_scheduler = trainer.lr_schedulers[0]["scheduler"]
rank_zero_warn(f"Swapping lr_scheduler {lr_scheduler} for {self._swa_scheduler}")
trainer.lr_schedulers[0]["scheduler"] = self._swa_scheduler
else:
_scheduler_config = _get_default_scheduler_config()
_scheduler_config["scheduler"] = self._swa_scheduler
trainer.lr_schedulers.append(_scheduler_config)
self.n_averaged = torch.tensor(0, dtype=torch.long, device=pl_module.device)

View File

@ -14,7 +14,12 @@
import os
from typing import List, Union
from pytorch_lightning.callbacks import Callback, ModelCheckpoint, ProgressBar, ProgressBarBase
from pytorch_lightning.callbacks import (
Callback,
ModelCheckpoint,
ProgressBar,
ProgressBarBase,
)
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities import rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@ -34,12 +39,14 @@ class CallbackConnector:
default_root_dir,
weights_save_path,
resume_from_checkpoint,
stochastic_weight_avg,
):
self.trainer.resume_from_checkpoint = resume_from_checkpoint
# init folder paths for checkpoint + weights save callbacks
self.trainer._default_root_dir = default_root_dir or os.getcwd()
self.trainer._weights_save_path = weights_save_path or self.trainer._default_root_dir
self.trainer._stochastic_weight_avg = stochastic_weight_avg
# init callbacks
if isinstance(callbacks, Callback):
@ -50,6 +57,9 @@ class CallbackConnector:
# pass through the required args to figure out defaults
self.configure_checkpoint_callbacks(checkpoint_callback)
# configure swa callback
self._configure_swa_callbacks()
# init progress bar
self.trainer._progress_bar_callback = self.configure_progress_bar(progress_bar_refresh_rate, process_position)
@ -76,6 +86,15 @@ class CallbackConnector:
if not self._trainer_has_checkpoint_callbacks() and checkpoint_callback is True:
self.trainer.callbacks.append(ModelCheckpoint(dirpath=None, filename=None, mode='min'))
def _configure_swa_callbacks(self):
if not self.trainer._stochastic_weight_avg:
return
from pytorch_lightning.callbacks.swa import StochasticWeightAveraging
existing_swa = [cb for cb in self.trainer.callbacks if isinstance(cb, StochasticWeightAveraging)]
if not existing_swa:
self.trainer.callbacks = [StochasticWeightAveraging()] + self.trainer.callbacks
def configure_progress_bar(self, refresh_rate=None, process_position=0):
if os.getenv('COLAB_GPU') and refresh_rate is None:
# smaller refresh rate on colab causes crashes, choose a higher value

View File

@ -13,7 +13,7 @@
# limitations under the License.
from abc import ABC
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Dict, Any
import torch
from torch import optim
@ -98,15 +98,7 @@ class TrainerOptimizersMixin(ABC):
def configure_schedulers(self, schedulers: list, monitor: Optional[str] = None):
# Convert each scheduler into dict structure with relevant information
lr_schedulers = []
default_config = {
'scheduler': None,
'name': None, # no custom name
'interval': 'epoch', # after epoch is over
'frequency': 1, # every epoch/batch
'reduce_on_plateau': False, # most often not ReduceLROnPlateau scheduler
'monitor': monitor, # value to monitor for ReduceLROnPlateau
'strict': True, # enforce that the monitor exists for ReduceLROnPlateau
}
default_config = _get_default_scheduler_config()
for scheduler in schedulers:
if isinstance(scheduler, dict):
# check provided keys
@ -185,3 +177,15 @@ def _validate_scheduler_optimizer(optimizers, lr_schedulers):
raise MisconfigurationException(
"Some schedulers are attatched with an optimizer that wasn't returned from `configure_optimizers`."
)
def _get_default_scheduler_config() -> Dict[str, Any]:
return {
'scheduler': None,
'name': None, # no custom name
'interval': 'epoch', # after epoch is over
'frequency': 1, # every epoch/batch
'reduce_on_plateau': False, # most often not ReduceLROnPlateau scheduler
'monitor': None, # value to monitor for ReduceLROnPlateau
'strict': True, # enforce that the monitor exists for ReduceLROnPlateau
}

View File

@ -139,6 +139,7 @@ class Trainer(
move_metrics_to_cpu: bool = False,
enable_pl_optimizer: bool = None, # todo: remove in v1.3
multiple_trainloader_mode: str = 'max_size_cycle',
stochastic_weight_avg: bool = False
):
r"""
Customize every aspect of training via flags
@ -297,6 +298,10 @@ class Trainer(
In 'max_size_cycle' mode, the trainer ends one epoch when the largest dataset is traversed,
and smaller datasets reload when running out of their data. In 'min_size' mode, all the datasets
reload when reaching the minimum length of datasets.
stochastic_weight_avg: Whether to use `Stochastic Weight Averaging (SWA)
<https://pytorch.org/blog/pytorch-1.6-now-includes-stochastic-weight-averaging/>_`
"""
super().__init__()
self._running_stage = None
@ -333,13 +338,8 @@ class Trainer(
# init callbacks
# Declare attributes to be set in callback_connector on_trainer_init
self.callback_connector.on_trainer_init(
callbacks,
checkpoint_callback,
progress_bar_refresh_rate,
process_position,
default_root_dir,
weights_save_path,
resume_from_checkpoint,
callbacks, checkpoint_callback, progress_bar_refresh_rate, process_position, default_root_dir,
weights_save_path, resume_from_checkpoint, stochastic_weight_avg
)
# hook

View File

@ -157,3 +157,32 @@ def test_swa_raises():
StochasticWeightAveraging(swa_epoch_start=-1, swa_lrs=0.1)
with pytest.raises(MisconfigurationException, match="positive float or a list of positive float"):
StochasticWeightAveraging(swa_epoch_start=5, swa_lrs=[0.2, 1])
@pytest.mark.parametrize('stochastic_weight_avg', [False, True])
@pytest.mark.parametrize('use_callbacks', [False, True])
@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_6, reason="SWA available from PyTorch 1.6.0")
def test_trainer_and_stochastic_weight_avg(tmpdir, use_callbacks, stochastic_weight_avg):
"""Test to ensure SWA Callback is injected when `stochastic_weight_avg` is provided to the Trainer"""
class TestModel(BoringModel):
def configure_optimizers(self):
optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
return optimizer
model = TestModel()
trainer = Trainer(
default_root_dir=tmpdir,
callbacks=StochasticWeightAveraging(swa_lrs=1e-3) if use_callbacks else None,
stochastic_weight_avg=stochastic_weight_avg,
limit_train_batches=4,
limit_val_batches=4,
max_epochs=2,
)
trainer.fit(model)
if use_callbacks or stochastic_weight_avg:
assert len([cb for cb in trainer.callbacks if isinstance(cb, StochasticWeightAveraging)]) == 1
assert trainer.callbacks[0]._swa_lrs == (1e-3 if use_callbacks else 0.1)
else:
assert all(not isinstance(cb, StochasticWeightAveraging) for cb in trainer.callbacks)