Fixes #2936 (no fix needed) (#3892)

This commit is contained in:
William Falcon 2020-10-05 23:15:52 -04:00 committed by GitHub
parent 893bed741f
commit cb2a3265e5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 49 additions and 1 deletions

View File

@ -1,9 +1,10 @@
import pytest
import torch
from pytorch_lightning import Trainer
from pytorch_lightning import Trainer, Callback
from tests.base import EvalModelTemplate
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base.boring_model import BoringModel
def test_optimizer_with_scheduling(tmpdir):
@ -298,3 +299,50 @@ def test_init_optimizers_during_testing(tmpdir):
assert len(trainer.lr_schedulers) == 0
assert len(trainer.optimizers) == 0
assert len(trainer.optimizer_frequencies) == 0
def test_multiple_optimizers_callbacks(tmpdir):
"""
Tests that multiple optimizers can be used with callbacks
"""
class CB(Callback):
def on_train_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
pass
def on_train_epoch_start(self, trainer, pl_module):
pass
class TestModel(BoringModel):
def __init__(self):
super().__init__()
self.layer_1 = torch.nn.Linear(32, 2)
self.layer_2 = torch.nn.Linear(32, 2)
def training_step(self, batch, batch_idx, optimizer_idx):
if optimizer_idx == 0:
a = batch[0]
acc = self.layer_1(a)
else:
a = batch[0]
acc = self.layer_2(a)
acc = self.loss(acc, acc)
return acc
def configure_optimizers(self):
a = torch.optim.RMSprop(self.layer_1.parameters(), 1e-2)
b = torch.optim.RMSprop(self.layer_2.parameters(), 1e-2)
return a, b
model = TestModel()
model.training_epoch_end = None
trainer = Trainer(
callbacks=[CB()],
default_root_dir=tmpdir,
limit_train_batches=1,
limit_val_batches=2,
max_epochs=1,
weights_summary=None,
)
trainer.fit(model)