# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
from pathlib import Path
from typing import ContextManager, Optional
from unittest import mock

import pytest
import torch
from torch import nn
from torch.optim.lr_scheduler import LambdaLR
from torch.optim.swa_utils import SWALR
from torch.utils.data import DataLoader

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import StochasticWeightAveraging
from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset, RandomIterableDataset
from lightning.pytorch.strategies import Strategy
from lightning.pytorch.strategies.launchers import _MultiProcessingLauncher
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from tests_pytorch.helpers.runif import RunIf


def test_swa_callback_initial_state():
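    """Constructor arguments are stored unmodified on the callback's private attributes."""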
    swa = StochasticWeightAveraging(
        swa_lrs=0.01,
        swa_epoch_start=0.1,
        annealing_epochs=1,
        annealing_strategy="linear",
        avg_fn=sum,
    )
    assert swa._swa_lrs == 0.01
    assert swa._swa_epoch_start == 0.1
    assert swa._annealing_epochs == 1
    assert swa._annealing_strategy == "linear"
    assert swa._avg_fn == sum
    assert swa._average_model is None


class SwaTestModel(BoringModel):
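    """`BoringModel` variant with an optional batch norm layer, a configurable scheduler interval, an optional
    iterable dataset, and a switch to crash at a given epoch so resuming from a checkpoint can be exercised."""
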
    def __init__(
        self, batchnorm: bool = True, interval: str = "epoch", iterable_dataset: bool = False, crash_on_epoch=None
    ):
        super().__init__()
        layers = [nn.Linear(32, 32)]
        if batchnorm:
            layers.append(nn.BatchNorm1d(32))
        layers += [nn.ReLU(), nn.Linear(32, 2)]
        self.layer = nn.Sequential(*layers)
        self.interval = interval
        self.iterable_dataset = iterable_dataset
        self.crash_on_epoch = crash_on_epoch

    def training_step(self, batch, batch_idx):
        if self.crash_on_epoch and self.trainer.current_epoch >= self.crash_on_epoch:
            raise Exception("SWA crash test")
        return super().training_step(batch, batch_idx)

    def train_dataloader(self):
        dset_cls = RandomIterableDataset if self.iterable_dataset else RandomDataset
        dset = dset_cls(32, 64)
        return DataLoader(dset, batch_size=2)

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": torch.optim.lr_scheduler.StepLR(optimizer, step_size=1),
                "interval": self.interval,
            },
        }


class SwaTestCallback(StochasticWeightAveraging):
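    """SWA callback that counts its own hook calls and asserts the expected trainer state at every stage."""
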
    update_parameters_calls: int = 0
    transfer_weights_calls: int = 0
    # Record the first epoch; if we are resuming from a checkpoint this may not be equal to 0
    first_epoch: Optional[int] = None

    def update_parameters(self, *args, **kwargs):
        self.update_parameters_calls += 1
        return StochasticWeightAveraging.update_parameters(*args, **kwargs)

    def transfer_weights(self, *args, **kwargs):
        self.transfer_weights_calls += 1
        return StochasticWeightAveraging.transfer_weights(*args, **kwargs)

    def on_train_epoch_start(self, trainer, *args):
        super().on_train_epoch_start(trainer, *args)
        if self.first_epoch is None and not trainer.fit_loop.restarting:
            # Since the loaded checkpoint was saved `on_train_epoch_end`, the first `FitLoop` iteration will not
            # update the model and will only call the epoch-level hooks; for that reason, we check that we are not
            # restarting before choosing the first epoch.
            self.first_epoch = trainer.current_epoch
        assert trainer.fit_loop._skip_backward == (trainer.current_epoch > self.swa_end)
        if self.swa_start <= trainer.current_epoch:
            assert isinstance(trainer.lr_scheduler_configs[0].scheduler, SWALR)
            assert trainer.lr_scheduler_configs[0].interval == "epoch"
            assert trainer.lr_scheduler_configs[0].frequency == 1

    def on_train_epoch_end(self, trainer, *args):
        super().on_train_epoch_end(trainer, *args)
        if self.swa_start <= trainer.current_epoch <= self.swa_end:
            swa_epoch = trainer.current_epoch - self.swa_start
            assert self.n_averaged == swa_epoch + 1
            assert self._swa_scheduler is not None
            # The scheduler is stepped once on initialization and then at the end of each epoch
            assert self._swa_scheduler._step_count == swa_epoch + 2
        elif trainer.current_epoch > self.swa_end:
            assert self.n_averaged == self._max_epochs - self.swa_start

    def on_train_end(self, trainer, pl_module):
        super().on_train_end(trainer, pl_module)

        # make sure these are correctly set again
        assert not trainer.fit_loop._skip_backward
        assert trainer.accumulate_grad_batches == 2
        assert trainer.num_training_batches == 5

        if not isinstance(trainer.strategy.launcher, _MultiProcessingLauncher):
            # check the backward call count; the batch norm update epoch should not call backward
            assert trainer.strategy.backward.call_count == (
                (trainer.max_epochs - self.first_epoch) * trainer.limit_train_batches
            )

        # check call counts
        first_swa_epoch = max(self.first_epoch, self.swa_start)
        assert self.update_parameters_calls == trainer.max_epochs - first_swa_epoch
        assert self.transfer_weights_calls == 1


def train_with_swa(
    tmpdir,
    batchnorm=True,
    strategy="auto",
    accelerator="cpu",
    devices=1,
    interval="epoch",
    iterable_dataset=False,
):
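    """Run a short fit with `SwaTestCallback` on the given accelerator/strategy and check the resulting state."""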
    model = SwaTestModel(batchnorm=batchnorm, interval=interval, iterable_dataset=iterable_dataset)
    swa_start = 2
    max_epochs = 5
    swa_callback = SwaTestCallback(swa_epoch_start=swa_start, swa_lrs=0.1)
    assert swa_callback.update_parameters_calls == 0
    assert swa_callback.transfer_weights_calls == 0

    trainer = Trainer(
        default_root_dir=tmpdir,
        enable_progress_bar=False,
        enable_model_summary=False,
        max_epochs=max_epochs,
        limit_train_batches=5,
        limit_val_batches=0,
        callbacks=[swa_callback],
        accumulate_grad_batches=2,
        strategy=strategy,
        accelerator=accelerator,
        devices=devices,
    )

    with _backward_patch(trainer):
        trainer.fit(model)

    # check that the trained model is the expected one
    assert trainer.lightning_module == model


@RunIf(min_cuda_gpus=2, standalone=True)
def test_swa_callback_ddp(tmpdir):
    train_with_swa(tmpdir, strategy="ddp", accelerator="gpu", devices=2)


@RunIf(min_cuda_gpus=2)
def test_swa_callback_ddp_spawn(tmpdir):
    train_with_swa(tmpdir, strategy="ddp_spawn", accelerator="gpu", devices=2)


@RunIf(skip_windows=True)
def test_swa_callback_ddp_cpu(tmpdir):
    train_with_swa(tmpdir, strategy="ddp_spawn", accelerator="cpu", devices=2)


@pytest.mark.parametrize(
    "accelerator", [pytest.param("gpu", marks=RunIf(min_cuda_gpus=1)), pytest.param("mps", marks=RunIf(mps=True))]
)
def test_swa_callback_1_gpu(tmpdir, accelerator):
    train_with_swa(tmpdir, accelerator=accelerator, devices=1)


@pytest.mark.parametrize("batchnorm", [True, False])
@pytest.mark.parametrize("iterable_dataset", [True, False])
def test_swa_callback(tmpdir, batchnorm: bool, iterable_dataset: bool):
    train_with_swa(tmpdir, batchnorm=batchnorm, iterable_dataset=iterable_dataset)


@pytest.mark.parametrize("interval", ["epoch", "step"])
def test_swa_callback_scheduler_step(tmpdir, interval: str):
    train_with_swa(tmpdir, interval=interval)


def test_swa_warns(tmpdir, caplog):
    model = SwaTestModel(interval="step")
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, callbacks=StochasticWeightAveraging(swa_lrs=1e-2))
    with caplog.at_level(level=logging.INFO), pytest.warns(UserWarning, match="SWA is currently only supported"):
        trainer.fit(model)
    assert "Swapping scheduler `StepLR` for `SWALR`" in caplog.text


def test_swa_raises():
    with pytest.raises(MisconfigurationException, match=">0 integer or a float between 0 and 1"):
        StochasticWeightAveraging(swa_epoch_start=0, swa_lrs=0.1)
    with pytest.raises(MisconfigurationException, match=">0 integer or a float between 0 and 1"):
        StochasticWeightAveraging(swa_epoch_start=1.5, swa_lrs=0.1)
    with pytest.raises(MisconfigurationException, match=">0 integer or a float between 0 and 1"):
        StochasticWeightAveraging(swa_epoch_start=-1, swa_lrs=0.1)
    with pytest.raises(MisconfigurationException, match="positive float, or a list of positive floats"):
        StochasticWeightAveraging(swa_epoch_start=5, swa_lrs=[0.2, 1])


def test_swa_deepcopy(tmpdir):
    """Ensure the SWA callback doesn't deepcopy the dataloaders and datamodule, which could lead to OOM."""

    class TestSWA(StochasticWeightAveraging):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.setup_called = False

        def setup(self, trainer, pl_module, stage) -> None:
            super().setup(trainer, pl_module, stage)
            assert self._average_model.train_dataloader is not pl_module.train_dataloader
            assert self._average_model.train_dataloader.__self__ == self._average_model
            assert self._average_model._trainer is None
            self.setup_called = True

    model = BoringModel()
    swa = TestSWA(swa_lrs=1e-2)
    trainer = Trainer(default_root_dir=tmpdir, callbacks=swa, fast_dev_run=True)
    trainer.fit(model, train_dataloaders=DataLoader(RandomDataset(32, 2)))
    assert swa.setup_called


def test_swa_multiple_lrs(tmpdir):
    swa_lrs = [0.123, 0.321]

    class TestModel(BoringModel):
        def __init__(self):
            super().__init__()
            self.layer1 = torch.nn.Linear(32, 32)
            self.layer2 = torch.nn.Linear(32, 2)
            self.on_train_epoch_start_called = False

        def forward(self, x):
            x = self.layer1(x)
            x = self.layer2(x)
            return x

        def configure_optimizers(self):
            params = [{"params": self.layer1.parameters(), "lr": 0.1}, {"params": self.layer2.parameters(), "lr": 0.2}]
            return torch.optim.Adam(params)

        def on_train_epoch_start(self):
            optimizer = trainer.optimizers[0]
            assert [pg["lr"] for pg in optimizer.param_groups] == [0.1, 0.2]
            assert [pg["initial_lr"] for pg in optimizer.param_groups] == swa_lrs
            assert [pg["swa_lr"] for pg in optimizer.param_groups] == swa_lrs
            self.on_train_epoch_start_called = True

    model = TestModel()
    swa_callback = StochasticWeightAveraging(swa_lrs=swa_lrs)
    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=swa_callback,
        fast_dev_run=1,
    )
    trainer.fit(model)
    assert model.on_train_epoch_start_called


def _swa_resume_training_from_checkpoint(tmpdir, model, resume_model, ddp=False):
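    """Fit `model` until it crashes, then resume `resume_model` from the saved checkpoint under `SwaTestCallback`."""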
    swa_start = 3
    trainer_kwargs = {
        "default_root_dir": tmpdir,
        "max_epochs": 5,
        "accelerator": "cpu",
        "strategy": "ddp_spawn" if ddp else "auto",
        "devices": 2 if ddp else 1,
        "limit_train_batches": 5,
        "limit_val_batches": 0,
        "accumulate_grad_batches": 2,
        "enable_progress_bar": False,
        "logger": False,
    }
    trainer = Trainer(callbacks=SwaTestCallback(swa_epoch_start=swa_start, swa_lrs=0.1), **trainer_kwargs)

    with _backward_patch(trainer), pytest.raises(Exception, match="SWA crash test"):
        trainer.fit(model)

    checkpoint_dir = Path(tmpdir) / "checkpoints"
    checkpoint_files = os.listdir(checkpoint_dir)
    assert len(checkpoint_files) == 1
    ckpt_path = str(checkpoint_dir / checkpoint_files[0])

    trainer = Trainer(callbacks=SwaTestCallback(swa_epoch_start=swa_start, swa_lrs=0.1), **trainer_kwargs)

    with _backward_patch(trainer):
        trainer.fit(resume_model, ckpt_path=ckpt_path)


class CustomSchedulerModel(SwaTestModel):
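    """`SwaTestModel` variant that uses a `LambdaLR` scheduler instead of `StepLR`."""
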
    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)

        def lr_lambda(current_step: int):
            return 0.1

        scheduler = LambdaLR(optimizer, lr_lambda, -1)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": self.interval,
            },
        }
@pytest.mark.parametrize("crash_on_epoch", [1, 3])
def test_swa_resume_training_from_checkpoint(tmpdir, crash_on_epoch):
model = SwaTestModel(crash_on_epoch=crash_on_epoch)
resume_model = SwaTestModel()
_swa_resume_training_from_checkpoint(tmpdir, model, resume_model)
@pytest.mark.parametrize("crash_on_epoch", [1, 3])
def test_swa_resume_training_from_checkpoint_custom_scheduler(tmpdir, crash_on_epoch):
# Reproduces the bug reported in https://github.com/Lightning-AI/lightning/issues/11665
model = CustomSchedulerModel(crash_on_epoch=crash_on_epoch)
resume_model = CustomSchedulerModel()
_swa_resume_training_from_checkpoint(tmpdir, model, resume_model)
@RunIf(skip_windows=True)
def test_swa_resume_training_from_checkpoint_ddp(tmpdir):
model = SwaTestModel(crash_on_epoch=3)
resume_model = SwaTestModel()
_swa_resume_training_from_checkpoint(tmpdir, model, resume_model, ddp=True)


@pytest.mark.parametrize(
    "strategy",
    [
        pytest.param("deepspeed", marks=RunIf(deepspeed=True, min_cuda_gpus=1)),
        pytest.param("fsdp", marks=RunIf(min_cuda_gpus=1, skip_windows=True, min_torch="1.12")),
    ],
)
def test_misconfiguration_error_with_sharded_model(tmpdir, strategy: str):
    model = SwaTestModel()
    swa_callback = SwaTestCallback(swa_epoch_start=2, swa_lrs=0.1)
    trainer = Trainer(
        default_root_dir=tmpdir,
        enable_progress_bar=False,
        max_epochs=5,
        callbacks=[swa_callback],
        strategy=strategy,
        accelerator="gpu",
        devices=1,
    )
    with pytest.raises(MisconfigurationException, match="SWA does not currently support sharded models"):
        trainer.fit(model)


def _backward_patch(trainer: Trainer) -> ContextManager:
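    """Patch `Strategy.backward` with a mock that wraps the real implementation so tests can count backward calls."""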
    return mock.patch.object(Strategy, "backward", wraps=trainer.strategy.backward)