[feat] Add Trainer(stochastic_weight_avg=True/False) (#6038)
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Kaushik B <45285388+kaushikb11@users.noreply.github.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

parent: 8d7ac8f0f8
commit: c9622bafe0
@@ -111,6 +111,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Added `PL_TORCH_DISTRIBUTED_BACKEND` env variable to select backend ([#5981](https://github.com/PyTorchLightning/pytorch-lightning/pull/5981))

+- Added `Trainer` flag to activate Stochastic Weight Averaging (SWA) `Trainer(stochastic_weight_avg=True)` ([#6038](https://github.com/PyTorchLightning/pytorch-lightning/pull/6038))
+
 ### Changed

 - Changed `stat_scores` metric now calculates stat scores over all classes and gains new parameters, in line with the new `StatScores` metric ([#4839](https://github.com/PyTorchLightning/pytorch-lightning/pull/4839))
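Before the file-by-file hunks, a minimal usage sketch of the two ways SWA can now be enabled (the new `Trainer` flag, or passing the callback explicitly); `MyLightningModule` is a placeholder for any `LightningModule`, and the epoch count and `swa_lrs` value are illustrative:

import pytorch_lightning as pl
from pytorch_lightning.callbacks import StochasticWeightAveraging

model = MyLightningModule()  # placeholder: any LightningModule

# Option 1: the new flag injects a default StochasticWeightAveraging callback.
trainer = pl.Trainer(max_epochs=5, stochastic_weight_avg=True)

# Option 2: pass the callback explicitly to control its hyperparameters.
trainer = pl.Trainer(max_epochs=5, callbacks=[StochasticWeightAveraging(swa_lrs=1e-3)])

trainer.fit(model)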
@@ -37,7 +37,6 @@ __all__ = [
     'ModelPruning',
     'ProgressBar',
     'ProgressBarBase',
-    'ModelPruning',
     'QuantizationAwareTraining',
     'StochasticWeightAveraging',
 ]
@@ -23,6 +23,7 @@ from torch import nn

 import pytorch_lightning as pl
 from pytorch_lightning.callbacks.base import Callback
+from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config
 from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_6, rank_zero_warn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -96,8 +97,10 @@ class StochasticWeightAveraging(Callback):
             raise MisconfigurationException(err_msg)

         if (
+            swa_lrs is not None and (
                 not isinstance(swa_lrs, (float, list)) or isinstance(swa_lrs, float) and swa_lrs <= 0
                 or isinstance(swa_lrs, list) and not all(lr > 0 and isinstance(lr, float) for lr in swa_lrs)
+            )
         ):
             raise MisconfigurationException("The `swa_lrs` should be a positive float or a list of positive float.")
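As the check above implies, `swa_lrs` may be left as `None`, given as a single positive float, or given as a list of positive floats; anything else raises `MisconfigurationException`. A few illustrative constructor calls (the values themselves are arbitrary):

StochasticWeightAveraging()                      # None: reuse each param group's current lr
StochasticWeightAveraging(swa_lrs=0.05)          # a single positive float
StochasticWeightAveraging(swa_lrs=[0.05, 0.01])  # a list of positive floats
StochasticWeightAveraging(swa_lrs=[0.2, 1])      # raises: 1 is an int, not a float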
@@ -131,11 +134,13 @@ class StochasticWeightAveraging(Callback):
     def on_before_accelerator_backend_setup(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule'):
         # copy the model before moving it to accelerator device.
         self._average_model = deepcopy(pl_module)

+    def on_fit_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule'):
         optimizers = trainer.optimizers
         lr_schedulers = trainer.lr_schedulers

-        if len(optimizers) > 1:
-            raise MisconfigurationException("SWA currently not supported for more than 1 `optimizer`.")
+        if len(optimizers) != 1:
+            raise MisconfigurationException("SWA currently works with 1 `optimizer`.")

         if len(lr_schedulers) > 1:
             raise MisconfigurationException("SWA currently not supported for more than 1 `lr_scheduler`.")
@@ -156,18 +161,37 @@ class StochasticWeightAveraging(Callback):
             self._average_model = self._average_model.to(self._device or pl_module.device)

             optimizers = trainer.optimizers
-            lr_scheduler = trainer.lr_schedulers[0]["scheduler"]

+            for param_group in optimizers[0].param_groups:
+                if self._swa_lrs is None:
+                    initial_lr = param_group["lr"]
+
+                elif isinstance(self._swa_lrs, float):
+                    initial_lr = self._swa_lrs
+
+                else:
+                    initial_lr = self._swa_lrs[0]
+
+                param_group["initial_lr"] = initial_lr
+
+            self._swa_lrs = initial_lr
+
             self._swa_scheduler = SWALR(
                 optimizers[0],
-                swa_lr=self._swa_lrs,
+                swa_lr=initial_lr,
                 anneal_epochs=self._annealing_epochs,
                 anneal_strategy=self._annealing_strategy,
                 last_epoch=trainer.max_epochs if self._annealing_strategy == "cos" else -1
             )

+            if trainer.lr_schedulers:
+                lr_scheduler = trainer.lr_schedulers[0]["scheduler"]
                 rank_zero_warn(f"Swapping lr_scheduler {lr_scheduler} for {self._swa_scheduler}")
                 trainer.lr_schedulers[0]["scheduler"] = self._swa_scheduler
+            else:
+                _scheduler_config = _get_default_scheduler_config()
+                _scheduler_config["scheduler"] = self._swa_scheduler
+                trainer.lr_schedulers.append(_scheduler_config)

             self.n_averaged = torch.tensor(0, dtype=torch.long, device=pl_module.device)
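The hunk above wires the callback into PyTorch's `torch.optim.swa_utils` machinery: it resolves an initial SWA learning rate per parameter group, builds an `SWALR` scheduler, and either swaps it in for the existing scheduler or appends a new scheduler config. For orientation, a rough standalone sketch of the manual workflow this automates (data, model, and schedule values are placeholders, and the callback's own bookkeeping is simplified away):

import torch
from torch.optim.swa_utils import AveragedModel, SWALR, update_bn
from torch.utils.data import DataLoader, TensorDataset

# Placeholder data and model; the SWA mechanics are the point here.
loader = DataLoader(TensorDataset(torch.randn(64, 32), torch.randn(64, 2)), batch_size=16)
model = torch.nn.Linear(32, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = torch.nn.MSELoss()

swa_model = AveragedModel(model)  # keeps a running average of the weights
swa_scheduler = SWALR(optimizer, swa_lr=0.05, anneal_epochs=10, anneal_strategy="cos")

swa_start = 5
for epoch in range(20):
    for x, y in loader:
        optimizer.zero_grad()
        loss_fn(model(x), y).backward()
        optimizer.step()
    if epoch >= swa_start:
        swa_model.update_parameters(model)  # accumulate the average
        swa_scheduler.step()                # anneal towards the SWA learning rate

update_bn(loader, swa_model)  # recompute BatchNorm statistics for the averaged weights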
@@ -14,7 +14,12 @@
 import os
 from typing import List, Union

-from pytorch_lightning.callbacks import Callback, ModelCheckpoint, ProgressBar, ProgressBarBase
+from pytorch_lightning.callbacks import (
+    Callback,
+    ModelCheckpoint,
+    ProgressBar,
+    ProgressBarBase,
+)
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.utilities import rank_zero_info, rank_zero_warn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -34,12 +39,14 @@ class CallbackConnector:
         default_root_dir,
         weights_save_path,
         resume_from_checkpoint,
+        stochastic_weight_avg,
     ):
         self.trainer.resume_from_checkpoint = resume_from_checkpoint

         # init folder paths for checkpoint + weights save callbacks
         self.trainer._default_root_dir = default_root_dir or os.getcwd()
         self.trainer._weights_save_path = weights_save_path or self.trainer._default_root_dir
+        self.trainer._stochastic_weight_avg = stochastic_weight_avg

         # init callbacks
         if isinstance(callbacks, Callback):
@@ -50,6 +57,9 @@ class CallbackConnector:
         # pass through the required args to figure out defaults
         self.configure_checkpoint_callbacks(checkpoint_callback)

+        # configure swa callback
+        self._configure_swa_callbacks()
+
         # init progress bar
         self.trainer._progress_bar_callback = self.configure_progress_bar(progress_bar_refresh_rate, process_position)
@@ -76,6 +86,15 @@ class CallbackConnector:
         if not self._trainer_has_checkpoint_callbacks() and checkpoint_callback is True:
             self.trainer.callbacks.append(ModelCheckpoint(dirpath=None, filename=None, mode='min'))

+    def _configure_swa_callbacks(self):
+        if not self.trainer._stochastic_weight_avg:
+            return
+
+        from pytorch_lightning.callbacks.swa import StochasticWeightAveraging
+        existing_swa = [cb for cb in self.trainer.callbacks if isinstance(cb, StochasticWeightAveraging)]
+        if not existing_swa:
+            self.trainer.callbacks = [StochasticWeightAveraging()] + self.trainer.callbacks
+
     def configure_progress_bar(self, refresh_rate=None, process_position=0):
         if os.getenv('COLAB_GPU') and refresh_rate is None:
             # smaller refresh rate on colab causes crashes, choose a higher value
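A small illustration of the deduplication in `_configure_swa_callbacks` above: when the user already passes a `StochasticWeightAveraging` callback, setting the flag as well does not prepend a second, default instance (the parametrized test at the end of this diff asserts the same thing after `fit`). A sketch:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import StochasticWeightAveraging

trainer = Trainer(
    callbacks=[StochasticWeightAveraging(swa_lrs=1e-3)],
    stochastic_weight_avg=True,
)
swa_callbacks = [cb for cb in trainer.callbacks if isinstance(cb, StochasticWeightAveraging)]
assert len(swa_callbacks) == 1  # the user-provided instance is kept; no default is added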
@@ -13,7 +13,7 @@
 # limitations under the License.

 from abc import ABC
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Dict, Any

 import torch
 from torch import optim
@@ -98,15 +98,7 @@ class TrainerOptimizersMixin(ABC):
     def configure_schedulers(self, schedulers: list, monitor: Optional[str] = None):
         # Convert each scheduler into dict structure with relevant information
         lr_schedulers = []
-        default_config = {
-            'scheduler': None,
-            'name': None,  # no custom name
-            'interval': 'epoch',  # after epoch is over
-            'frequency': 1,  # every epoch/batch
-            'reduce_on_plateau': False,  # most often not ReduceLROnPlateau scheduler
-            'monitor': monitor,  # value to monitor for ReduceLROnPlateau
-            'strict': True,  # enforce that the monitor exists for ReduceLROnPlateau
-        }
+        default_config = _get_default_scheduler_config()
         for scheduler in schedulers:
             if isinstance(scheduler, dict):
                 # check provided keys
@@ -185,3 +177,15 @@ def _validate_scheduler_optimizer(optimizers, lr_schedulers):
         raise MisconfigurationException(
             "Some schedulers are attatched with an optimizer that wasn't returned from `configure_optimizers`."
         )
+
+
+def _get_default_scheduler_config() -> Dict[str, Any]:
+    return {
+        'scheduler': None,
+        'name': None,  # no custom name
+        'interval': 'epoch',  # after epoch is over
+        'frequency': 1,  # every epoch/batch
+        'reduce_on_plateau': False,  # most often not ReduceLROnPlateau scheduler
+        'monitor': None,  # value to monitor for ReduceLROnPlateau
+        'strict': True,  # enforce that the monitor exists for ReduceLROnPlateau
+    }
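The new `_get_default_scheduler_config` helper centralizes the scheduler-dict defaults that `configure_schedulers` previously built inline, so the SWA callback can reuse them when it appends its own scheduler. For orientation, a hedged sketch of the user-facing scheduler dict whose omitted keys get filled from these defaults (the module and hyperparameters are placeholders):

import torch
import pytorch_lightning as pl

class MyLightningModule(pl.LightningModule):  # placeholder module, training_step etc. omitted

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=0.1)
        scheduler = {
            'scheduler': torch.optim.lr_scheduler.StepLR(optimizer, step_size=10),
            'interval': 'epoch',  # same default as _get_default_scheduler_config()
            'frequency': 1,
            # 'monitor' and 'strict' only matter for ReduceLROnPlateau
        }
        return [optimizer], [scheduler]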
@@ -139,6 +139,7 @@ class Trainer(
         move_metrics_to_cpu: bool = False,
         enable_pl_optimizer: bool = None,  # todo: remove in v1.3
         multiple_trainloader_mode: str = 'max_size_cycle',
+        stochastic_weight_avg: bool = False
     ):
         r"""
         Customize every aspect of training via flags
@@ -297,6 +298,10 @@ class Trainer(
             In 'max_size_cycle' mode, the trainer ends one epoch when the largest dataset is traversed,
             and smaller datasets reload when running out of their data. In 'min_size' mode, all the datasets
             reload when reaching the minimum length of datasets.

+            stochastic_weight_avg: Whether to use `Stochastic Weight Averaging (SWA)
+                <https://pytorch.org/blog/pytorch-1.6-now-includes-stochastic-weight-averaging/>`_
+
         """
         super().__init__()
         self._running_stage = None
@@ -333,13 +338,8 @@ class Trainer(
         # init callbacks
         # Declare attributes to be set in callback_connector on_trainer_init
         self.callback_connector.on_trainer_init(
-            callbacks,
-            checkpoint_callback,
-            progress_bar_refresh_rate,
-            process_position,
-            default_root_dir,
-            weights_save_path,
-            resume_from_checkpoint,
+            callbacks, checkpoint_callback, progress_bar_refresh_rate, process_position, default_root_dir,
+            weights_save_path, resume_from_checkpoint, stochastic_weight_avg
         )

         # hook
@@ -157,3 +157,32 @@ def test_swa_raises():
         StochasticWeightAveraging(swa_epoch_start=-1, swa_lrs=0.1)
     with pytest.raises(MisconfigurationException, match="positive float or a list of positive float"):
         StochasticWeightAveraging(swa_epoch_start=5, swa_lrs=[0.2, 1])
+
+
+@pytest.mark.parametrize('stochastic_weight_avg', [False, True])
+@pytest.mark.parametrize('use_callbacks', [False, True])
+@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_6, reason="SWA available from PyTorch 1.6.0")
+def test_trainer_and_stochastic_weight_avg(tmpdir, use_callbacks, stochastic_weight_avg):
+    """Test to ensure SWA Callback is injected when `stochastic_weight_avg` is provided to the Trainer"""
+
+    class TestModel(BoringModel):
+
+        def configure_optimizers(self):
+            optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
+            return optimizer
+
+    model = TestModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        callbacks=StochasticWeightAveraging(swa_lrs=1e-3) if use_callbacks else None,
+        stochastic_weight_avg=stochastic_weight_avg,
+        limit_train_batches=4,
+        limit_val_batches=4,
+        max_epochs=2,
+    )
+    trainer.fit(model)
+    if use_callbacks or stochastic_weight_avg:
+        assert len([cb for cb in trainer.callbacks if isinstance(cb, StochasticWeightAveraging)]) == 1
+        assert trainer.callbacks[0]._swa_lrs == (1e-3 if use_callbacks else 0.1)
+    else:
+        assert all(not isinstance(cb, StochasticWeightAveraging) for cb in trainer.callbacks)