Enforce an epoch scheduler interval when using SWA (#6588)
Co-authored-by: Carlos Mocholi <carlossmocholi@gmail.com>
parent 7f91c5ebbc
commit 6dc1078822
@@ -200,6 +200,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed torch distributed not available in setup hook for DDP ([#6506](https://github.com/PyTorchLightning/pytorch-lightning/pull/6506))


+- Enforce an epoch scheduler interval when using SWA ([#6588](https://github.com/PyTorchLightning/pytorch-lightning/pull/6588))
+
+
 - Fixed an issue with `IterableDataset` when `__len__` is not defined ([#6828](https://github.com/PyTorchLightning/pytorch-lightning/pull/6828))
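Context for the entry above: the SWA callback swaps the user's LR scheduler for a `SWALR` instance, and `SWALR` is designed to be stepped once per epoch. Before this commit the callback only replaced the scheduler object, so a user configuration with `interval="step"` kept stepping `SWALR` on every batch. The callback hunk below installs the full default scheduler config instead, pinning `interval="epoch"` and `frequency=1`. A minimal sketch of that behaviour (not the library code; `make_default_scheduler_config` is a stand-in for Lightning's internal `_get_default_scheduler_config()`):

```python
# Sketch only: mimics the enforcement added in this commit, under the assumption
# that a scheduler config is a dict with "scheduler", "interval" and "frequency".
from torch.optim.swa_utils import SWALR


def make_default_scheduler_config() -> dict:
    # mirrors the defaults asserted in the diff below: interval="epoch", frequency=1
    return {"scheduler": None, "interval": "epoch", "frequency": 1}


def install_swa_scheduler(lr_schedulers: list, swa_scheduler: SWALR) -> None:
    config = make_default_scheduler_config()
    assert config["interval"] == "epoch" and config["frequency"] == 1
    config["scheduler"] = swa_scheduler
    if lr_schedulers:
        # replace the whole config entry, not just the scheduler object,
        # so a user-defined interval="step" cannot leak through
        lr_schedulers[0] = config
    else:
        lr_schedulers.append(config)
```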
@@ -187,14 +187,15 @@ class StochasticWeightAveraging(Callback):
                 anneal_strategy=self._annealing_strategy,
                 last_epoch=trainer.max_epochs if self._annealing_strategy == "cos" else -1
             )
+            _scheduler_config = _get_default_scheduler_config()
+            assert _scheduler_config["interval"] == "epoch" and _scheduler_config["frequency"] == 1
+            _scheduler_config["scheduler"] = self._swa_scheduler

             if trainer.lr_schedulers:
                 lr_scheduler = trainer.lr_schedulers[0]["scheduler"]
                 rank_zero_warn(f"Swapping lr_scheduler {lr_scheduler} for {self._swa_scheduler}")
-                trainer.lr_schedulers[0]["scheduler"] = self._swa_scheduler
+                trainer.lr_schedulers[0] = _scheduler_config
             else:
-                _scheduler_config = _get_default_scheduler_config()
-                _scheduler_config["scheduler"] = self._swa_scheduler
                 trainer.lr_schedulers.append(_scheduler_config)

             self.n_averaged = torch.tensor(0, dtype=torch.long, device=pl_module.device)
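For reference, a plain-PyTorch sketch (not part of this diff) of why `SWALR` pairs with an epoch interval: it anneals the learning rate toward the SWA value over `anneal_epochs` and is meant to be stepped once per epoch, which is exactly the `interval="epoch"` / `frequency=1` config asserted above.

```python
# Standalone SWALR usage stepped once per epoch; model and optimizer are toy stand-ins.
import torch
from torch import nn
from torch.optim.swa_utils import SWALR

model = nn.Linear(32, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
swa_scheduler = SWALR(optimizer, swa_lr=0.05, anneal_epochs=3, anneal_strategy="cos")

for epoch in range(5):
    # ... the training batches for this epoch would call optimizer.step() here ...
    optimizer.step()
    swa_scheduler.step()  # one step per epoch, matching interval="epoch"
```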
@@ -27,16 +27,18 @@ from tests.helpers.runif import RunIf

 if _TORCH_GREATER_EQUAL_1_6:
     from pytorch_lightning.callbacks import StochasticWeightAveraging
+    from torch.optim.swa_utils import SWALR

     class SwaTestModel(BoringModel):

-        def __init__(self, batchnorm: bool = True):
+        def __init__(self, batchnorm: bool = True, interval: str = "epoch"):
             super().__init__()
             layers = [nn.Linear(32, 32)]
             if batchnorm:
                 layers.append(nn.BatchNorm1d(32))
             layers += [nn.ReLU(), nn.Linear(32, 2)]
             self.layer = nn.Sequential(*layers)
+            self.interval = interval

         def training_step(self, batch, batch_idx):
             output = self.forward(batch)
@@ -46,6 +48,14 @@ if _TORCH_GREATER_EQUAL_1_6:
         def train_dataloader(self):
             return DataLoader(RandomDataset(32, 64), batch_size=2)

+        def configure_optimizers(self):
+            optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
+            return {
+                "optimizer": optimizer,
+                "scheduler": torch.optim.lr_scheduler.StepLR(optimizer, step_size=1),
+                "interval": self.interval,
+            }
+
     class SwaTestCallback(StochasticWeightAveraging):
         update_parameters_calls: int = 0
         transfer_weights_calls: int = 0
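The test model above returns a flat dict with a configurable `interval`. For comparison, a hedged sketch of how a user module would typically request a per-step scheduler through Lightning's documented `lr_scheduler` dict format; with the SWA callback enabled, that scheduler is swapped for `SWALR` and stepped per epoch regardless. `StepIntervalModel` is a hypothetical example, not code from this PR.

```python
# Hypothetical user module whose own scheduler runs per step ("interval": "step").
import torch
from torch import nn
import pytorch_lightning as pl


class StepIntervalModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        # toy loss, enough to drive the optimizer in a sketch
        return self.layer(batch).sum()

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
        }
```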
@@ -61,6 +71,10 @@ if _TORCH_GREATER_EQUAL_1_6:
         def on_train_epoch_start(self, trainer, *args):
             super().on_train_epoch_start(trainer, *args)
             assert trainer.train_loop._skip_backward == (trainer.current_epoch > self.swa_end)
+            if self.swa_start <= trainer.current_epoch:
+                assert isinstance(trainer.lr_schedulers[0]["scheduler"], SWALR)
+                assert trainer.lr_schedulers[0]["interval"] == "epoch"
+                assert trainer.lr_schedulers[0]["frequency"] == 1

         def on_train_epoch_end(self, trainer, *args):
             super().on_train_epoch_end(trainer, *args)
@@ -89,8 +103,8 @@ if _TORCH_GREATER_EQUAL_1_6:


 @mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
-def train_with_swa(tmpdir, batchnorm=True, accelerator=None, gpus=None, num_processes=1):
-    model = SwaTestModel(batchnorm=batchnorm)
+def train_with_swa(tmpdir, batchnorm=True, accelerator=None, gpus=None, num_processes=1, interval="epoch"):
+    model = SwaTestModel(batchnorm=batchnorm, interval=interval)
     swa_start = 2
     max_epochs = 5
     swa_callback = SwaTestCallback(swa_epoch_start=swa_start, swa_lrs=0.1)
@@ -140,6 +154,12 @@ def test_swa_callback(tmpdir, batchnorm: bool):
     train_with_swa(tmpdir, batchnorm=batchnorm)


+@RunIf(min_torch="1.6.0")
+@pytest.mark.parametrize("interval", ("epoch", "step"))
+def test_swa_callback_scheduler_step(tmpdir, interval: bool):
+    train_with_swa(tmpdir, interval=interval)
+
+
 @RunIf(min_torch="1.6.0")
 def test_swa_raises():
     with pytest.raises(MisconfigurationException, match=">0 integer or a float between 0 and 1"):
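Finally, a hedged end-to-end sketch of the scenario the new parametrized test exercises: training a model whose own scheduler uses `interval="step"` while the SWA callback is attached. It reuses the hypothetical `StepIntervalModel` from the sketch above; the callback is expected to warn that it is swapping the scheduler and to step the `SWALR` replacement once per epoch from `swa_epoch_start` onwards.

```python
# Assumes StepIntervalModel from the earlier sketch is defined or importable here.
import torch
from torch.utils.data import DataLoader
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import StochasticWeightAveraging

model = StepIntervalModel()
train_data = DataLoader(torch.randn(64, 32), batch_size=2)  # random toy data

trainer = Trainer(
    max_epochs=5,
    callbacks=[StochasticWeightAveraging(swa_epoch_start=2, swa_lrs=0.1)],
)
trainer.fit(model, train_data)
```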