From cb2a3265e5eb329a48fb44df6ab8fd74df62b85a Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 5 Oct 2020 23:15:52 -0400 Subject: [PATCH] Fixes #2936 (no fix needed) (#3892) --- tests/trainer/test_optimizers.py | 50 +++++++++++++++++++++++++++++++- 1 file changed, 49 insertions(+), 1 deletion(-) diff --git a/tests/trainer/test_optimizers.py b/tests/trainer/test_optimizers.py index 33f0ca8ced..34112e7cbd 100644 --- a/tests/trainer/test_optimizers.py +++ b/tests/trainer/test_optimizers.py @@ -1,9 +1,10 @@ import pytest import torch -from pytorch_lightning import Trainer +from pytorch_lightning import Trainer, Callback from tests.base import EvalModelTemplate from pytorch_lightning.utilities.exceptions import MisconfigurationException +from tests.base.boring_model import BoringModel def test_optimizer_with_scheduling(tmpdir): @@ -298,3 +299,50 @@ def test_init_optimizers_during_testing(tmpdir): assert len(trainer.lr_schedulers) == 0 assert len(trainer.optimizers) == 0 assert len(trainer.optimizer_frequencies) == 0 + + +def test_multiple_optimizers_callbacks(tmpdir): + """ + Tests that multiple optimizers can be used with callbacks + """ + class CB(Callback): + + def on_train_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx): + pass + + def on_train_epoch_start(self, trainer, pl_module): + pass + + class TestModel(BoringModel): + def __init__(self): + super().__init__() + self.layer_1 = torch.nn.Linear(32, 2) + self.layer_2 = torch.nn.Linear(32, 2) + + def training_step(self, batch, batch_idx, optimizer_idx): + if optimizer_idx == 0: + a = batch[0] + acc = self.layer_1(a) + else: + a = batch[0] + acc = self.layer_2(a) + + acc = self.loss(acc, acc) + return acc + + def configure_optimizers(self): + a = torch.optim.RMSprop(self.layer_1.parameters(), 1e-2) + b = torch.optim.RMSprop(self.layer_2.parameters(), 1e-2) + return a, b + + model = TestModel() + model.training_epoch_end = None + trainer = Trainer( + callbacks=[CB()], + default_root_dir=tmpdir, + limit_train_batches=1, + limit_val_batches=2, + max_epochs=1, + weights_summary=None, + ) + trainer.fit(model)