diff --git a/CHANGELOG.md b/CHANGELOG.md
index 3d4eeddb5b..2de9810968 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -111,6 +111,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added `PL_TORCH_DISTRIBUTED_BACKEND` env variable to select backend ([#5981](https://github.com/PyTorchLightning/pytorch-lightning/pull/5981))
 
+- Added `Trainer` flag to activate Stochastic Weight Averaging (SWA) `Trainer(stochastic_weight_avg=True)` ([#6038](https://github.com/PyTorchLightning/pytorch-lightning/pull/6038))
+
+
 ### Changed
 
 - Changed `stat_scores` metric now calculates stat scores over all classes and gains new parameters, in line with the new `StatScores` metric ([#4839](https://github.com/PyTorchLightning/pytorch-lightning/pull/4839))
diff --git a/pytorch_lightning/callbacks/__init__.py b/pytorch_lightning/callbacks/__init__.py
index 514782addd..f3787c1cb2 100644
--- a/pytorch_lightning/callbacks/__init__.py
+++ b/pytorch_lightning/callbacks/__init__.py
@@ -37,7 +37,6 @@ __all__ = [
     'ModelPruning',
     'ProgressBar',
     'ProgressBarBase',
-    'ModelPruning',
     'QuantizationAwareTraining',
     'StochasticWeightAveraging',
 ]
diff --git a/pytorch_lightning/callbacks/swa.py b/pytorch_lightning/callbacks/swa.py
index 2cd573e5a6..fc7a2c75c0 100644
--- a/pytorch_lightning/callbacks/swa.py
+++ b/pytorch_lightning/callbacks/swa.py
@@ -23,6 +23,7 @@ from torch import nn
 
 import pytorch_lightning as pl
 from pytorch_lightning.callbacks.base import Callback
+from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config
 from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_6, rank_zero_warn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -96,8 +97,10 @@ class StochasticWeightAveraging(Callback):
             raise MisconfigurationException(err_msg)
 
         if (
-            not isinstance(swa_lrs, (float, list)) or isinstance(swa_lrs, float) and swa_lrs <= 0
-            or isinstance(swa_lrs, list) and not all(lr > 0 and isinstance(lr, float) for lr in swa_lrs)
+            swa_lrs is not None and (
+                not isinstance(swa_lrs, (float, list)) or isinstance(swa_lrs, float) and swa_lrs <= 0
+                or isinstance(swa_lrs, list) and not all(lr > 0 and isinstance(lr, float) for lr in swa_lrs)
+            )
         ):
             raise MisconfigurationException("The `swa_lrs` should be a positive float or a list of positive float.")
 
@@ -131,11 +134,13 @@ class StochasticWeightAveraging(Callback):
     def on_before_accelerator_backend_setup(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule'):
         # copy the model before moving it to accelerator device.
         self._average_model = deepcopy(pl_module)
+
+    def on_fit_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule'):
         optimizers = trainer.optimizers
         lr_schedulers = trainer.lr_schedulers
 
-        if len(optimizers) > 1:
-            raise MisconfigurationException("SWA currently not supported for more than 1 `optimizer`.")
+        if len(optimizers) != 1:
+            raise MisconfigurationException("SWA currently works with 1 `optimizer`.")
 
         if len(lr_schedulers) > 1:
             raise MisconfigurationException("SWA currently not supported for more than 1 `lr_scheduler`.")
@@ -156,18 +161,37 @@ class StochasticWeightAveraging(Callback):
         self._average_model = self._average_model.to(self._device or pl_module.device)
 
         optimizers = trainer.optimizers
-        lr_scheduler = trainer.lr_schedulers[0]["scheduler"]
+
+        for param_group in optimizers[0].param_groups:
+            if self._swa_lrs is None:
+                initial_lr = param_group["lr"]
+
+            elif isinstance(self._swa_lrs, float):
+                initial_lr = self._swa_lrs
+
+            else:
+                initial_lr = self._swa_lrs[0]
+
+            param_group["initial_lr"] = initial_lr
+
+        self._swa_lrs = initial_lr
 
         self._swa_scheduler = SWALR(
             optimizers[0],
-            swa_lr=self._swa_lrs,
+            swa_lr=initial_lr,
             anneal_epochs=self._annealing_epochs,
             anneal_strategy=self._annealing_strategy,
             last_epoch=trainer.max_epochs if self._annealing_strategy == "cos" else -1
         )
-        rank_zero_warn(f"Swapping lr_scheduler {lr_scheduler} for {self._swa_scheduler}")
-        trainer.lr_schedulers[0]["scheduler"] = self._swa_scheduler
+
+        if trainer.lr_schedulers:
+            lr_scheduler = trainer.lr_schedulers[0]["scheduler"]
+            rank_zero_warn(f"Swapping lr_scheduler {lr_scheduler} for {self._swa_scheduler}")
+            trainer.lr_schedulers[0]["scheduler"] = self._swa_scheduler
+        else:
+            _scheduler_config = _get_default_scheduler_config()
+            _scheduler_config["scheduler"] = self._swa_scheduler
+            trainer.lr_schedulers.append(_scheduler_config)
 
         self.n_averaged = torch.tensor(0, dtype=torch.long, device=pl_module.device)
diff --git a/pytorch_lightning/trainer/connectors/callback_connector.py b/pytorch_lightning/trainer/connectors/callback_connector.py
index 6ea75c23fe..14694c8f77 100644
--- a/pytorch_lightning/trainer/connectors/callback_connector.py
+++ b/pytorch_lightning/trainer/connectors/callback_connector.py
@@ -14,7 +14,12 @@
 import os
 from typing import List, Union
 
-from pytorch_lightning.callbacks import Callback, ModelCheckpoint, ProgressBar, ProgressBarBase
+from pytorch_lightning.callbacks import (
+    Callback,
+    ModelCheckpoint,
+    ProgressBar,
+    ProgressBarBase,
+)
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.utilities import rank_zero_info, rank_zero_warn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -34,12 +39,14 @@ class CallbackConnector:
         default_root_dir,
         weights_save_path,
         resume_from_checkpoint,
+        stochastic_weight_avg,
     ):
         self.trainer.resume_from_checkpoint = resume_from_checkpoint
 
         # init folder paths for checkpoint + weights save callbacks
         self.trainer._default_root_dir = default_root_dir or os.getcwd()
         self.trainer._weights_save_path = weights_save_path or self.trainer._default_root_dir
+        self.trainer._stochastic_weight_avg = stochastic_weight_avg
 
         # init callbacks
         if isinstance(callbacks, Callback):
@@ -50,6 +57,9 @@ class CallbackConnector:
         # pass through the required args to figure out defaults
         self.configure_checkpoint_callbacks(checkpoint_callback)
 
+        # configure swa callback
+        self._configure_swa_callbacks()
+
         # init progress bar
         self.trainer._progress_bar_callback = self.configure_progress_bar(progress_bar_refresh_rate, process_position)
@@ -76,6 +86,15 @@ class CallbackConnector:
         if not self._trainer_has_checkpoint_callbacks() and checkpoint_callback is True:
             self.trainer.callbacks.append(ModelCheckpoint(dirpath=None, filename=None, mode='min'))
 
+    def _configure_swa_callbacks(self):
+        if not self.trainer._stochastic_weight_avg:
+            return
+
+        from pytorch_lightning.callbacks.swa import StochasticWeightAveraging
+        existing_swa = [cb for cb in self.trainer.callbacks if isinstance(cb, StochasticWeightAveraging)]
+        if not existing_swa:
+            self.trainer.callbacks = [StochasticWeightAveraging()] + self.trainer.callbacks
+
     def configure_progress_bar(self, refresh_rate=None, process_position=0):
         if os.getenv('COLAB_GPU') and refresh_rate is None:
             # smaller refresh rate on colab causes crashes, choose a higher value
diff --git a/pytorch_lightning/trainer/optimizers.py b/pytorch_lightning/trainer/optimizers.py
index eaf2231f5d..ea881b796e 100644
--- a/pytorch_lightning/trainer/optimizers.py
+++ b/pytorch_lightning/trainer/optimizers.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from abc import ABC
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Dict, Any
 
 import torch
 from torch import optim
@@ -98,15 +98,7 @@ class TrainerOptimizersMixin(ABC):
     def configure_schedulers(self, schedulers: list, monitor: Optional[str] = None):
         # Convert each scheduler into dict structure with relevant information
         lr_schedulers = []
-        default_config = {
-            'scheduler': None,
-            'name': None,  # no custom name
-            'interval': 'epoch',  # after epoch is over
-            'frequency': 1,  # every epoch/batch
-            'reduce_on_plateau': False,  # most often not ReduceLROnPlateau scheduler
-            'monitor': monitor,  # value to monitor for ReduceLROnPlateau
-            'strict': True,  # enforce that the monitor exists for ReduceLROnPlateau
-        }
+        default_config = _get_default_scheduler_config()
         for scheduler in schedulers:
             if isinstance(scheduler, dict):
                 # check provided keys
@@ -185,3 +177,15 @@
         raise MisconfigurationException(
             "Some schedulers are attatched with an optimizer that wasn't returned from `configure_optimizers`."
         )
+
+
+def _get_default_scheduler_config() -> Dict[str, Any]:
+    return {
+        'scheduler': None,
+        'name': None,  # no custom name
+        'interval': 'epoch',  # after epoch is over
+        'frequency': 1,  # every epoch/batch
+        'reduce_on_plateau': False,  # most often not ReduceLROnPlateau scheduler
+        'monitor': None,  # value to monitor for ReduceLROnPlateau
+        'strict': True,  # enforce that the monitor exists for ReduceLROnPlateau
+    }
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 9eb75c9cfc..b9a0cc92a1 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -139,6 +139,7 @@ class Trainer(
         move_metrics_to_cpu: bool = False,
         enable_pl_optimizer: bool = None,  # todo: remove in v1.3
         multiple_trainloader_mode: str = 'max_size_cycle',
+        stochastic_weight_avg: bool = False
     ):
         r"""
         Customize every aspect of training via flags
@@ -297,6 +298,10 @@ class Trainer(
                 In 'max_size_cycle' mode, the trainer ends one epoch when the largest dataset is traversed,
                 and smaller datasets reload when running out of their data. In 'min_size' mode, all the datasets
                 reload when reaching the minimum length of datasets.
+
+            stochastic_weight_avg: Whether to use `Stochastic Weight Averaging (SWA)
+                <https://pytorch.org/blog/pytorch-1.6-now-includes-stochastic-weight-averaging/>`_
+
         """
         super().__init__()
         self._running_stage = None
@@ -333,13 +338,8 @@ class Trainer(
         # init callbacks
         # Declare attributes to be set in callback_connector on_trainer_init
         self.callback_connector.on_trainer_init(
-            callbacks,
-            checkpoint_callback,
-            progress_bar_refresh_rate,
-            process_position,
-            default_root_dir,
-            weights_save_path,
-            resume_from_checkpoint,
+            callbacks, checkpoint_callback, progress_bar_refresh_rate, process_position, default_root_dir,
+            weights_save_path, resume_from_checkpoint, stochastic_weight_avg
         )
 
         # hook
diff --git a/tests/callbacks/test_swa.py b/tests/callbacks/test_swa.py
index 72a4c4fc1a..337773994b 100644
--- a/tests/callbacks/test_swa.py
+++ b/tests/callbacks/test_swa.py
@@ -157,3 +157,32 @@ def test_swa_raises():
         StochasticWeightAveraging(swa_epoch_start=-1, swa_lrs=0.1)
     with pytest.raises(MisconfigurationException, match="positive float or a list of positive float"):
         StochasticWeightAveraging(swa_epoch_start=5, swa_lrs=[0.2, 1])
+
+
+@pytest.mark.parametrize('stochastic_weight_avg', [False, True])
+@pytest.mark.parametrize('use_callbacks', [False, True])
+@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_6, reason="SWA available from PyTorch 1.6.0")
+def test_trainer_and_stochastic_weight_avg(tmpdir, use_callbacks, stochastic_weight_avg):
+    """Test to ensure SWA Callback is injected when `stochastic_weight_avg` is provided to the Trainer"""
+
+    class TestModel(BoringModel):
+
+        def configure_optimizers(self):
+            optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
+            return optimizer
+
+    model = TestModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        callbacks=StochasticWeightAveraging(swa_lrs=1e-3) if use_callbacks else None,
+        stochastic_weight_avg=stochastic_weight_avg,
+        limit_train_batches=4,
+        limit_val_batches=4,
+        max_epochs=2,
+    )
+    trainer.fit(model)
+    if use_callbacks or stochastic_weight_avg:
+        assert len([cb for cb in trainer.callbacks if isinstance(cb, StochasticWeightAveraging)]) == 1
+        assert trainer.callbacks[0]._swa_lrs == (1e-3 if use_callbacks else 0.1)
+    else:
+        assert all(not isinstance(cb, StochasticWeightAveraging) for cb in trainer.callbacks)
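For reference, a minimal usage sketch of the two ways this diff lets SWA be enabled. `MyModel` is a hypothetical `LightningModule` subclass used only for illustration, and SWA still requires PyTorch >= 1.6 (it relies on `torch.optim.swa_utils`):

    import pytorch_lightning as pl
    from pytorch_lightning.callbacks import StochasticWeightAveraging

    model = MyModel()  # hypothetical LightningModule subclass

    # Option 1: the new Trainer flag injects a default StochasticWeightAveraging callback.
    trainer = pl.Trainer(max_epochs=10, stochastic_weight_avg=True)
    trainer.fit(model)

    # Option 2: pass the callback explicitly to control swa_lrs and the other SWA arguments.
    trainer = pl.Trainer(max_epochs=10, callbacks=[StochasticWeightAveraging(swa_lrs=1e-2)])
    trainer.fit(model)

Because `_configure_swa_callbacks` only prepends a default callback when no `StochasticWeightAveraging` instance is already present, combining the flag with an explicit callback never results in duplicates, which is what the new test asserts.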