[feat] Add Trainer(stochastic_weight_avg=True/False) (#6038)

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Kaushik B <45285388+kaushikb11@users.noreply.github.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
chaton authored 2021-02-17 23:10:05 +00:00, committed by GitHub
parent 8d7ac8f0f8
commit c9622bafe0
7 changed files with 105 additions and 27 deletions


@@ -111,6 +111,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `PL_TORCH_DISTRIBUTED_BACKEND` env variable to select backend ([#5981](https://github.com/PyTorchLightning/pytorch-lightning/pull/5981))
+- Added `Trainer` flag to activate Stochastic Weight Averaging (SWA) `Trainer(stochastic_weight_avg=True)` ([#6038](https://github.com/PyTorchLightning/pytorch-lightning/pull/6038))
 ### Changed
 - Changed `stat_scores` metric now calculates stat scores over all classes and gains new parameters, in line with the new `StatScores` metric ([#4839](https://github.com/PyTorchLightning/pytorch-lightning/pull/4839))
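For context, here is a minimal usage sketch of the flag this changelog entry describes. The toy module, data, and hyperparameters are illustrative only and are not part of this commit:

import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl

class ToyModel(pl.LightningModule):
    """Illustrative LightningModule, not taken from the PR."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)

train_loader = DataLoader(TensorDataset(torch.randn(64, 32), torch.randn(64, 2)), batch_size=8)

# stochastic_weight_avg defaults to False; setting it to True asks the Trainer to inject
# a StochasticWeightAveraging callback with default settings (see the connector change below).
trainer = pl.Trainer(max_epochs=5, stochastic_weight_avg=True)
trainer.fit(ToyModel(), train_loader)

Passing a configured `StochasticWeightAveraging(...)` callback explicitly remains supported, as the tests at the end of this diff exercise.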


@@ -37,7 +37,6 @@ __all__ = [
     'ModelPruning',
     'ProgressBar',
     'ProgressBarBase',
-    'ModelPruning',
     'QuantizationAwareTraining',
     'StochasticWeightAveraging',
 ]


@@ -23,6 +23,7 @@ from torch import nn
 import pytorch_lightning as pl
 from pytorch_lightning.callbacks.base import Callback
+from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config
 from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_6, rank_zero_warn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -96,8 +97,10 @@ class StochasticWeightAveraging(Callback):
             raise MisconfigurationException(err_msg)

         if (
+            swa_lrs is not None and (
             not isinstance(swa_lrs, (float, list)) or isinstance(swa_lrs, float) and swa_lrs <= 0
             or isinstance(swa_lrs, list) and not all(lr > 0 and isinstance(lr, float) for lr in swa_lrs)
+            )
         ):
             raise MisconfigurationException("The `swa_lrs` should be a positive float or a list of positive float.")
@@ -131,11 +134,13 @@
     def on_before_accelerator_backend_setup(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule'):
         # copy the model before moving it to accelerator device.
         self._average_model = deepcopy(pl_module)

+    def on_fit_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule'):
         optimizers = trainer.optimizers
         lr_schedulers = trainer.lr_schedulers

-        if len(optimizers) > 1:
-            raise MisconfigurationException("SWA currently not supported for more than 1 `optimizer`.")
+        if len(optimizers) != 1:
+            raise MisconfigurationException("SWA currently works with 1 `optimizer`.")

         if len(lr_schedulers) > 1:
             raise MisconfigurationException("SWA currently not supported for more than 1 `lr_scheduler`.")
@@ -156,18 +161,37 @@
         self._average_model = self._average_model.to(self._device or pl_module.device)

         optimizers = trainer.optimizers
-        lr_scheduler = trainer.lr_schedulers[0]["scheduler"]

+        for param_group in optimizers[0].param_groups:
+            if self._swa_lrs is None:
+                initial_lr = param_group["lr"]
+            elif isinstance(self._swa_lrs, float):
+                initial_lr = self._swa_lrs
+            else:
+                initial_lr = self._swa_lrs[0]
+            param_group["initial_lr"] = initial_lr
+
+        self._swa_lrs = initial_lr

         self._swa_scheduler = SWALR(
             optimizers[0],
-            swa_lr=self._swa_lrs,
+            swa_lr=initial_lr,
             anneal_epochs=self._annealing_epochs,
             anneal_strategy=self._annealing_strategy,
             last_epoch=trainer.max_epochs if self._annealing_strategy == "cos" else -1
         )

-        rank_zero_warn(f"Swapping lr_scheduler {lr_scheduler} for {self._swa_scheduler}")
-        trainer.lr_schedulers[0]["scheduler"] = self._swa_scheduler
+        if trainer.lr_schedulers:
+            lr_scheduler = trainer.lr_schedulers[0]["scheduler"]
+            rank_zero_warn(f"Swapping lr_scheduler {lr_scheduler} for {self._swa_scheduler}")
+            trainer.lr_schedulers[0]["scheduler"] = self._swa_scheduler
+        else:
+            _scheduler_config = _get_default_scheduler_config()
+            _scheduler_config["scheduler"] = self._swa_scheduler
+            trainer.lr_schedulers.append(_scheduler_config)

         self.n_averaged = torch.tensor(0, dtype=torch.long, device=pl_module.device)
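The hunks above cover the case where the LightningModule defines no lr scheduler: instead of swapping an existing scheduler for SWALR, the callback now registers SWALR through a default scheduler config. As a rough, hedged sketch of the underlying torch.optim.swa_utils machinery the callback drives (the loop below is illustrative, not how Lightning schedules it internally):

import torch
from torch.optim.swa_utils import SWALR, AveragedModel

net = torch.nn.Linear(32, 2)
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)

# Roughly what the callback builds for the single supported optimizer:
swa_scheduler = SWALR(optimizer, swa_lr=0.05, anneal_epochs=10, anneal_strategy="cos")
# Lightning keeps its own averaged copy of the module; AveragedModel plays that role here.
swa_model = AveragedModel(net)

max_epochs, swa_start = 20, 16  # comparable to swa_epoch_start=0.8 with 20 epochs
for epoch in range(max_epochs):
    # ... regular training steps would go here ...
    if epoch >= swa_start:
        swa_model.update_parameters(net)
        swa_scheduler.step()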


@@ -14,7 +14,12 @@
 import os
 from typing import List, Union

-from pytorch_lightning.callbacks import Callback, ModelCheckpoint, ProgressBar, ProgressBarBase
+from pytorch_lightning.callbacks import (
+    Callback,
+    ModelCheckpoint,
+    ProgressBar,
+    ProgressBarBase,
+)
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.utilities import rank_zero_info, rank_zero_warn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -34,12 +39,14 @@ class CallbackConnector:
         default_root_dir,
         weights_save_path,
         resume_from_checkpoint,
+        stochastic_weight_avg,
     ):
         self.trainer.resume_from_checkpoint = resume_from_checkpoint

         # init folder paths for checkpoint + weights save callbacks
         self.trainer._default_root_dir = default_root_dir or os.getcwd()
         self.trainer._weights_save_path = weights_save_path or self.trainer._default_root_dir
+        self.trainer._stochastic_weight_avg = stochastic_weight_avg

         # init callbacks
         if isinstance(callbacks, Callback):
@@ -50,6 +57,9 @@ class CallbackConnector:
         # pass through the required args to figure out defaults
         self.configure_checkpoint_callbacks(checkpoint_callback)

+        # configure swa callback
+        self._configure_swa_callbacks()
+
         # init progress bar
         self.trainer._progress_bar_callback = self.configure_progress_bar(progress_bar_refresh_rate, process_position)
@@ -76,6 +86,15 @@ class CallbackConnector:
         if not self._trainer_has_checkpoint_callbacks() and checkpoint_callback is True:
             self.trainer.callbacks.append(ModelCheckpoint(dirpath=None, filename=None, mode='min'))

+    def _configure_swa_callbacks(self):
+        if not self.trainer._stochastic_weight_avg:
+            return
+
+        from pytorch_lightning.callbacks.swa import StochasticWeightAveraging
+        existing_swa = [cb for cb in self.trainer.callbacks if isinstance(cb, StochasticWeightAveraging)]
+        if not existing_swa:
+            self.trainer.callbacks = [StochasticWeightAveraging()] + self.trainer.callbacks
+
     def configure_progress_bar(self, refresh_rate=None, process_position=0):
         if os.getenv('COLAB_GPU') and refresh_rate is None:
             # smaller refresh rate on colab causes crashes, choose a higher value
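With `_configure_swa_callbacks` in place, enabling the flag and passing the callback explicitly should both end up with exactly one StochasticWeightAveraging instance, and combining the two should not duplicate it. A small illustrative check (not part of the diff):

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import StochasticWeightAveraging

# Flag only: the connector prepends StochasticWeightAveraging() with its defaults.
flag_only = Trainer(stochastic_weight_avg=True)

# Explicit callback plus the flag: the user-provided instance is kept, nothing extra is injected.
explicit = Trainer(callbacks=[StochasticWeightAveraging(swa_lrs=1e-2)], stochastic_weight_avg=True)

assert sum(isinstance(cb, StochasticWeightAveraging) for cb in flag_only.callbacks) == 1
assert sum(isinstance(cb, StochasticWeightAveraging) for cb in explicit.callbacks) == 1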


@@ -13,7 +13,7 @@
 # limitations under the License.
 from abc import ABC
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Dict, Any

 import torch
 from torch import optim
@@ -98,15 +98,7 @@ class TrainerOptimizersMixin(ABC):
     def configure_schedulers(self, schedulers: list, monitor: Optional[str] = None):
         # Convert each scheduler into dict structure with relevant information
         lr_schedulers = []
-        default_config = {
-            'scheduler': None,
-            'name': None,  # no custom name
-            'interval': 'epoch',  # after epoch is over
-            'frequency': 1,  # every epoch/batch
-            'reduce_on_plateau': False,  # most often not ReduceLROnPlateau scheduler
-            'monitor': monitor,  # value to monitor for ReduceLROnPlateau
-            'strict': True,  # enforce that the monitor exists for ReduceLROnPlateau
-        }
+        default_config = _get_default_scheduler_config()
         for scheduler in schedulers:
             if isinstance(scheduler, dict):
                 # check provided keys
@@ -185,3 +177,15 @@ def _validate_scheduler_optimizer(optimizers, lr_schedulers):
         raise MisconfigurationException(
             "Some schedulers are attatched with an optimizer that wasn't returned from `configure_optimizers`."
         )
+
+
+def _get_default_scheduler_config() -> Dict[str, Any]:
+    return {
+        'scheduler': None,
+        'name': None,  # no custom name
+        'interval': 'epoch',  # after epoch is over
+        'frequency': 1,  # every epoch/batch
+        'reduce_on_plateau': False,  # most often not ReduceLROnPlateau scheduler
+        'monitor': None,  # value to monitor for ReduceLROnPlateau
+        'strict': True,  # enforce that the monitor exists for ReduceLROnPlateau
+    }
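The keys returned by `_get_default_scheduler_config` mirror the scheduler dict a LightningModule may return from `configure_optimizers`; any key omitted by the user falls back to these defaults. A hedged example of that user-facing dict (module and values are illustrative, not from this commit):

import torch
import pytorch_lightning as pl

class SchedulerConfigExample(pl.LightningModule):
    """Illustrative module showing the user-facing scheduler dict."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=0.1)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,  # required; the remaining keys are optional
                "interval": "step",      # default is 'epoch'
                "frequency": 1,          # default is 1
                # 'monitor' and 'strict' matter for ReduceLROnPlateau schedulers
            },
        }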


@@ -139,6 +139,7 @@ class Trainer(
         move_metrics_to_cpu: bool = False,
         enable_pl_optimizer: bool = None,  # todo: remove in v1.3
         multiple_trainloader_mode: str = 'max_size_cycle',
+        stochastic_weight_avg: bool = False
     ):
         r"""
         Customize every aspect of training via flags
@@ -297,6 +298,10 @@ class Trainer(
             In 'max_size_cycle' mode, the trainer ends one epoch when the largest dataset is traversed,
             and smaller datasets reload when running out of their data. In 'min_size' mode, all the datasets
             reload when reaching the minimum length of datasets.

+            stochastic_weight_avg: Whether to use `Stochastic Weight Averaging (SWA)
+                <https://pytorch.org/blog/pytorch-1.6-now-includes-stochastic-weight-averaging/>`_
+
         """
         super().__init__()
         self._running_stage = None
@@ -333,13 +338,8 @@
         # init callbacks
         # Declare attributes to be set in callback_connector on_trainer_init
         self.callback_connector.on_trainer_init(
-            callbacks,
-            checkpoint_callback,
-            progress_bar_refresh_rate,
-            process_position,
-            default_root_dir,
-            weights_save_path,
-            resume_from_checkpoint,
+            callbacks, checkpoint_callback, progress_bar_refresh_rate, process_position, default_root_dir,
+            weights_save_path, resume_from_checkpoint, stochastic_weight_avg
         )

         # hook


@@ -157,3 +157,32 @@ def test_swa_raises():
         StochasticWeightAveraging(swa_epoch_start=-1, swa_lrs=0.1)
     with pytest.raises(MisconfigurationException, match="positive float or a list of positive float"):
         StochasticWeightAveraging(swa_epoch_start=5, swa_lrs=[0.2, 1])
+
+
+@pytest.mark.parametrize('stochastic_weight_avg', [False, True])
+@pytest.mark.parametrize('use_callbacks', [False, True])
+@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_6, reason="SWA available from PyTorch 1.6.0")
+def test_trainer_and_stochastic_weight_avg(tmpdir, use_callbacks, stochastic_weight_avg):
+    """Test to ensure SWA Callback is injected when `stochastic_weight_avg` is provided to the Trainer"""
+
+    class TestModel(BoringModel):
+
+        def configure_optimizers(self):
+            optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
+            return optimizer
+
+    model = TestModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        callbacks=StochasticWeightAveraging(swa_lrs=1e-3) if use_callbacks else None,
+        stochastic_weight_avg=stochastic_weight_avg,
+        limit_train_batches=4,
+        limit_val_batches=4,
+        max_epochs=2,
+    )
+    trainer.fit(model)
+
+    if use_callbacks or stochastic_weight_avg:
+        assert len([cb for cb in trainer.callbacks if isinstance(cb, StochasticWeightAveraging)]) == 1
+        assert trainer.callbacks[0]._swa_lrs == (1e-3 if use_callbacks else 0.1)
+    else:
+        assert all(not isinstance(cb, StochasticWeightAveraging) for cb in trainer.callbacks)