parent 893bed741f
commit cb2a3265e5
@@ -1,9 +1,10 @@
 import pytest
 import torch
 
-from pytorch_lightning import Trainer
+from pytorch_lightning import Trainer, Callback
 from tests.base import EvalModelTemplate
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from tests.base.boring_model import BoringModel
 
 
 def test_optimizer_with_scheduling(tmpdir):
@@ -298,3 +299,50 @@ def test_init_optimizers_during_testing(tmpdir):
     assert len(trainer.lr_schedulers) == 0
     assert len(trainer.optimizers) == 0
     assert len(trainer.optimizer_frequencies) == 0
+
+
+def test_multiple_optimizers_callbacks(tmpdir):
+    """
+    Tests that multiple optimizers can be used with callbacks
+    """
+    class CB(Callback):
+
+        def on_train_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
+            pass
+
+        def on_train_epoch_start(self, trainer, pl_module):
+            pass
+
+    class TestModel(BoringModel):
+        def __init__(self):
+            super().__init__()
+            self.layer_1 = torch.nn.Linear(32, 2)
+            self.layer_2 = torch.nn.Linear(32, 2)
+
+        def training_step(self, batch, batch_idx, optimizer_idx):
+            if optimizer_idx == 0:
+                a = batch[0]
+                acc = self.layer_1(a)
+            else:
+                a = batch[0]
+                acc = self.layer_2(a)
+
+            acc = self.loss(acc, acc)
+            return acc
+
+        def configure_optimizers(self):
+            a = torch.optim.RMSprop(self.layer_1.parameters(), 1e-2)
+            b = torch.optim.RMSprop(self.layer_2.parameters(), 1e-2)
+            return a, b
+
+    model = TestModel()
+    model.training_epoch_end = None
+    trainer = Trainer(
+        callbacks=[CB()],
+        default_root_dir=tmpdir,
+        limit_train_batches=1,
+        limit_val_batches=2,
+        max_epochs=1,
+        weights_summary=None,
+    )
+    trainer.fit(model)