ref: bug fix with logging val epoch end + monitor (#3812)

* ref: fix metric err

* ref: merge

* ref: decoupled ddp2

* ref: clean up ddp before final fix
William Falcon 2020-10-03 12:33:29 -04:00 committed by GitHub
parent ed1450a293
commit d9bc95f83e
30 changed files with 212 additions and 56 deletions
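
Note: as context for the change set below, here is a minimal, hedged sketch of the user-facing pattern this commit fixes (module, dataset, and metric names such as ToyModel and avg_val_loss are illustrative, modeled on the SimpleModule and test_monitor_val_epoch_end added in this commit): a metric logged only in validation_epoch_end should be usable as the `monitor` of a ModelCheckpoint.

import torch
from torch.utils.data import DataLoader, Dataset
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint


class RandomData(Dataset):
    # tiny random dataset, just enough to drive a smoke run
    def __init__(self, size=32, length=64):
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)


class ToyModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def _loss(self, prediction):
        # arbitrary loss so that fit() actually updates the weights
        return torch.nn.functional.mse_loss(prediction, torch.ones_like(prediction))

    def training_step(self, batch, batch_idx):
        return {"loss": self._loss(self(batch))}

    def validation_step(self, batch, batch_idx):
        return {"loss": self._loss(self(batch))}

    def validation_epoch_end(self, outputs):
        # the metric only exists at epoch end; the checkpoint callback monitors this key
        avg_val_loss = torch.stack([x["loss"] for x in outputs]).mean()
        self.log("avg_val_loss", avg_val_loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        return DataLoader(RandomData())

    def val_dataloader(self):
        return DataLoader(RandomData())


if __name__ == "__main__":
    checkpoint_callback = ModelCheckpoint(save_top_k=1, monitor="avg_val_loss")
    trainer = Trainer(max_epochs=2, logger=False, checkpoint_callback=checkpoint_callback)
    trainer.fit(ToyModel())

The test_monitor_val_epoch_end test added in this commit exercises essentially this combination.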

View File

@ -106,9 +106,9 @@ class EarlyStopping(Callback):
def _validate_condition_metric(self, logs):
monitor_val = logs.get(self.monitor)
error_msg = (f'Early stopping conditioned on metric `{self.monitor}`'
-f' which is not available. Either add `{self.monitor}` to the return of'
-' `validation_epoch_end` or modify your `EarlyStopping` callback to use any of the'
+f' which is not available. Pass in or modify your `EarlyStopping` callback to use any of the'
f' following: `{"`, `".join(list(logs.keys()))}`')
if monitor_val is None:
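
Note: a hedged usage sketch of what the shortened message now asks for, reusing the illustrative ToyModel from the sketch above; the monitored name must match a key logged (or returned) at validation epoch end, otherwise EarlyStopping emits the error shown in this hunk.

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

# 'avg_val_loss' is logged by ToyModel.validation_epoch_end in the sketch above
early_stop = EarlyStopping(monitor="avg_val_loss", patience=3)
trainer = Trainer(max_epochs=20, logger=False, early_stop_callback=early_stop)
trainer.fit(ToyModel())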

View File

@ -171,7 +171,10 @@ class LoggerConnector:
return result
def _track_callback_metrics(self, eval_results, using_eval_result):
-if len(eval_results) > 0 and eval_results[0] is None:
+if (
+    len(eval_results) > 0 and
+    (eval_results[0] is None or not isinstance(eval_results[0], Result))
+):
return
if using_eval_result:

View File

@ -43,7 +43,13 @@ class OptimizerConnector:
if lr_scheduler['interval'] == interval and current_idx % lr_scheduler['frequency'] == 0:
# If instance of ReduceLROnPlateau, we need to pass validation loss
if lr_scheduler['reduce_on_plateau']:
-monitor_key = lr_scheduler['monitor']
+try:
+    monitor_key = lr_scheduler['monitor']
+except KeyError as e:
+    m = "ReduceLROnPlateau requires returning a dict from configure_optimizers with the keyword " \
+        "monitor=. For example:" \
+        "return {'optimizer': optimizer, 'lr_scheduler': scheduler, 'monitor': 'your_loss'}"
+    raise MisconfigurationException(m)
if monitor_metrics is not None:
monitor_val = monitor_metrics.get(monitor_key)
@ -54,7 +60,7 @@ class OptimizerConnector:
avail_metrics = ','.join(list(self.trainer.logger_connector.callback_metrics.keys()))
raise MisconfigurationException(
f'ReduceLROnPlateau conditioned on metric {monitor_key}'
-f' which is not available. Available metrics are: {avail_metrics}.'
+f' which is not available. Available metrics are: [{avail_metrics}].'
' Condition can be set using `monitor` key in lr scheduler dict'
)
# update LR
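
Note: a hedged sketch of the configure_optimizers return shape the new error message asks for; only the scheduler wiring is shown (the usual steps and dataloaders are assumed to exist), and 'avg_val_loss' stands in for whatever metric the trainer actually tracks.

import torch
from pytorch_lightning import LightningModule


class PlateauModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
        # omitting 'monitor' with ReduceLROnPlateau now raises MisconfigurationException
        # instead of silently falling back to the old 'val_loss' default
        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "avg_val_loss"}

This mirrors the TestModel used by the updated test_reduce_lr_on_plateau_scheduling further down, which monitors 'early_stop_on'.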

View File

@ -228,6 +228,10 @@ class EvaluationLoop(object):
if using_eval_result and not user_reduced:
eval_results = self.__auto_reduce_result_objs(outputs)
+result = model._results
+if len(result) > 0 and eval_results is None:
+    eval_results = result.get_epoch_log_metrics()
if not isinstance(eval_results, list):
eval_results = [eval_results]

View File

@ -50,9 +50,10 @@ class TrainerOptimizersMixin(ABC):
# single dictionary
elif isinstance(optim_conf, dict):
optimizer = optim_conf["optimizer"]
+monitor = optim_conf.get('monitor', None)
lr_scheduler = optim_conf.get("lr_scheduler", [])
if lr_scheduler:
-    lr_schedulers = self.configure_schedulers([lr_scheduler])
+    lr_schedulers = self.configure_schedulers([lr_scheduler], monitor)
else:
lr_schedulers = []
return [optimizer], lr_schedulers, []
@ -94,13 +95,18 @@ class TrainerOptimizersMixin(ABC):
' a list of `torch.optim.lr_scheduler`'
' * multiple outputs, dictionaries as described with an optional `frequency` key (int)')
-def configure_schedulers(self, schedulers: list):
+def configure_schedulers(self, schedulers: list, monitor: str = None):
# Convert each scheduler into dict structure with relevant information
lr_schedulers = []
-default_config = {'interval': 'epoch',  # default every epoch
-                  'frequency': 1,  # default every epoch/batch
-                  'reduce_on_plateau': False,  # most often not ReduceLROnPlateau scheduler
-                  'monitor': 'val_loss'}  # default value to monitor for ReduceLROnPlateau
+default_config = {
+    'interval': 'epoch',  # default every epoch
+    'frequency': 1,  # default every epoch/batch
+    'reduce_on_plateau': False
+}  # most often not ReduceLROnPlateau scheduler
+if monitor is not None:
+    default_config['monitor'] = monitor
for scheduler in schedulers:
if isinstance(scheduler, dict):
if 'scheduler' not in scheduler:

View File

@ -531,6 +531,11 @@ class TrainLoop:
# TODO: add outputs to batches
self.on_train_batch_end(epoch_output, epoch_end_outputs, batch, batch_idx, dataloader_idx)
+# -----------------------------------------
+# SAVE METRICS TO LOGGERS
+# -----------------------------------------
+self.trainer.logger_connector.log_train_step_metrics(batch_output)
# -----------------------------------------
# VALIDATE IF NEEDED + CHECKPOINT CALLBACK
# -----------------------------------------
@ -538,11 +543,6 @@ class TrainLoop:
if should_check_val:
self.trainer.run_evaluation(test_mode=False)
-# -----------------------------------------
-# SAVE METRICS TO LOGGERS
-# -----------------------------------------
-self.trainer.logger_connector.log_train_step_metrics(batch_output)
# -----------------------------------------
# SAVE LOGGERS (ie: Tensorboard, etc...)
# -----------------------------------------

View File

@ -2,3 +2,4 @@
from tests.base.datasets import TrialMNIST
from tests.base.model_template import EvalModelTemplate, GenericEvalModelTemplate
+from tests.base.simple_model import SimpleModule

View File

@ -51,7 +51,7 @@ class ValidationEpochEndVariations(ABC):
val_loss_mean = val_loss_mean.item()
val_acc_mean = val_acc_mean.item()
-metrics_dict = {'val_loss': val_loss_mean, 'val_acc': val_acc_mean}
+metrics_dict = {'early_stop_on': val_loss_mean, 'val_acc': val_acc_mean}
results = {'progress_bar': metrics_dict, 'log': metrics_dict}
return results

View File

@ -0,0 +1,85 @@
import torch
from pytorch_lightning import LightningModule
from torch.utils.data import Dataset
from typing import Optional


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class SimpleModule(LightningModule):
    def __init__(self, epoch_min_loss_override: Optional[int] = None):
        """LightningModule for testing purposes

        Args:
            epoch_min_loss_override (int, optional): Pass in an epoch that will be set to the minimum
                validation loss for testing purposes (zero based). If None this is ignored. Defaults to None.
        """
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        self.epoch_min_loss_override = epoch_min_loss_override

    def forward(self, x):
        return self.layer(x)

    def loss(self, batch, prediction):
        # An arbitrary loss to have a loss that updates the model weights during `Trainer.fit` calls
        return torch.nn.functional.mse_loss(prediction, torch.ones_like(prediction))

    def training_step(self, batch, batch_idx):
        output = self.forward(batch)
        loss = self.loss(batch, output)
        return {"output": output, "loss": loss, "checkpoint_on": loss}

    def validation_step(self, batch, batch_idx):
        output = self.forward(batch)
        loss = self.loss(batch, output)
        return {"output": output, "loss": loss, "checkpoint_on": loss}

    def test_step(self, batch, batch_idx):
        output = self.forward(batch)
        loss = self.loss(batch, output)
        return {"output": output, "loss": loss}

    def training_epoch_end(self, outputs) -> None:
        avg_loss = torch.stack([x["loss"] for x in outputs]).mean()
        self.log("avg_loss", avg_loss)

    def validation_epoch_end(self, outputs) -> None:
        avg_val_loss = torch.stack(
            [torch.randn(1, requires_grad=True) for _ in outputs]
        ).mean()
        # For testing purposes allow a nominated epoch to have a low loss
        if self.current_epoch == self.epoch_min_loss_override:
            avg_val_loss -= 1e10
        self.log("avg_val_loss", avg_val_loss)
        self.log("checkpoint_on", avg_val_loss)

    def test_epoch_end(self, outputs) -> None:
        avg_loss = torch.stack(
            [torch.randn(1, requires_grad=True) for _ in outputs]
        ).mean()
        self.log("test_loss", avg_loss)

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        return [optimizer], [lr_scheduler]

    def train_dataloader(self):
        return torch.utils.data.DataLoader(RandomDataset(32, 64))

    def val_dataloader(self):
        return torch.utils.data.DataLoader(RandomDataset(32, 64))

    def test_dataloader(self):
        return torch.utils.data.DataLoader(RandomDataset(32, 64))

View File

@ -37,7 +37,7 @@ def test_resume_early_stopping_from_checkpoint(tmpdir):
"""
model = EvalModelTemplate()
checkpoint_callback = ModelCheckpoint(monitor="val_loss", save_top_k=1)
checkpoint_callback = ModelCheckpoint(monitor="early_stop_on", save_top_k=1)
early_stop_callback = EarlyStoppingTestRestore()
trainer = Trainer(
default_root_dir=tmpdir,
@ -159,13 +159,13 @@ def test_early_stopping_functionality(tmpdir):
def validation_epoch_end(self, outputs):
losses = [8, 4, 2, 3, 4, 5, 8, 10]
val_loss = losses[self.current_epoch]
-return {'val_loss': torch.tensor(val_loss)}
+self.log('abc', torch.tensor(val_loss))
model = CurrentModel()
trainer = Trainer(
default_root_dir=tmpdir,
-early_stop_callback=True,
+early_stop_callback=EarlyStopping(monitor='abc'),
overfit_batches=0.20,
max_epochs=20,
)

View File

@ -25,7 +25,7 @@ def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k):
tutils.reset_seed()
model = EvalModelTemplate()
-checkpoint = ModelCheckpoint(monitor='val_loss', filepath=None, save_top_k=save_top_k)
+checkpoint = ModelCheckpoint(monitor='early_stop_on', filepath=None, save_top_k=save_top_k)
trainer = Trainer(
default_root_dir=tmpdir,
@ -45,7 +45,7 @@ def test_model_checkpoint_to_yaml(tmpdir, save_top_k):
tutils.reset_seed()
model = EvalModelTemplate()
-checkpoint = ModelCheckpoint(filepath=tmpdir, monitor='val_loss', save_top_k=save_top_k)
+checkpoint = ModelCheckpoint(filepath=tmpdir, monitor='early_stop_on', save_top_k=save_top_k)
trainer = Trainer(default_root_dir=tmpdir, checkpoint_callback=checkpoint, overfit_batches=0.20, max_epochs=2)
trainer.fit(model)
@ -124,7 +124,7 @@ def test_model_checkpoint_no_extraneous_invocations(tmpdir):
"""Test to ensure that the model callback saves the checkpoints only once in distributed mode."""
model = EvalModelTemplate()
num_epochs = 4
-model_checkpoint = ModelCheckpointTestInvocations(monitor='val_loss', expected_count=num_epochs, save_top_k=-1)
+model_checkpoint = ModelCheckpointTestInvocations(monitor='early_stop_on', expected_count=num_epochs, save_top_k=-1)
trainer = Trainer(
distributed_backend="ddp_cpu",
num_processes=2,
@ -156,23 +156,23 @@ def test_model_checkpoint_format_checkpoint_name(tmpdir):
assert ckpt_name == 'test@epoch=3,acc=0.03000'
ModelCheckpoint.CHECKPOINT_JOIN_CHAR = char_org
# no filepath set
-ckpt_name = ModelCheckpoint(monitor='val_loss', filepath=None).format_checkpoint_name(3, {})
+ckpt_name = ModelCheckpoint(monitor='early_stop_on', filepath=None).format_checkpoint_name(3, {})
assert ckpt_name == 'epoch=3.ckpt'
-ckpt_name = ModelCheckpoint(monitor='val_loss', filepath='').format_checkpoint_name(5, {})
+ckpt_name = ModelCheckpoint(monitor='early_stop_on', filepath='').format_checkpoint_name(5, {})
assert ckpt_name == 'epoch=5.ckpt'
# CWD
-ckpt_name = ModelCheckpoint(monitor='val_loss', filepath='.').format_checkpoint_name(3, {})
+ckpt_name = ModelCheckpoint(monitor='early_stop_on', filepath='.').format_checkpoint_name(3, {})
assert Path(ckpt_name) == Path('.') / 'epoch=3.ckpt'
# dir does not exist so it is used as filename
filepath = tmpdir / 'dir'
-ckpt_name = ModelCheckpoint(monitor='val_loss', filepath=filepath, prefix='test').format_checkpoint_name(3, {})
+ckpt_name = ModelCheckpoint(monitor='early_stop_on', filepath=filepath, prefix='test').format_checkpoint_name(3, {})
assert ckpt_name == tmpdir / 'test-dir.ckpt'
# now, dir exists
os.mkdir(filepath)
-ckpt_name = ModelCheckpoint(monitor='val_loss', filepath=filepath, prefix='test').format_checkpoint_name(3, {})
+ckpt_name = ModelCheckpoint(monitor='early_stop_on', filepath=filepath, prefix='test').format_checkpoint_name(3, {})
assert ckpt_name == filepath / 'test-epoch=3.ckpt'
# with ver
-ckpt_name = ModelCheckpoint(monitor='val_loss',
+ckpt_name = ModelCheckpoint(monitor='early_stop_on',
filepath=tmpdir / 'name', prefix='test').format_checkpoint_name(3, {}, ver=3)
assert ckpt_name == tmpdir / 'test-name-v3.ckpt'
@ -182,7 +182,7 @@ def test_model_checkpoint_save_last(tmpdir):
model = EvalModelTemplate()
epochs = 3
ModelCheckpoint.CHECKPOINT_NAME_LAST = 'last-{epoch}'
-model_checkpoint = ModelCheckpoint(monitor='val_loss', filepath=tmpdir, save_top_k=-1, save_last=True)
+model_checkpoint = ModelCheckpoint(monitor='early_stop_on', filepath=tmpdir, save_top_k=-1, save_last=True)
trainer = Trainer(
default_root_dir=tmpdir,
early_stop_callback=False,
@ -304,7 +304,7 @@ def test_model_checkpoint_topk_all(tmpdir):
seed_everything(1000)
epochs = 2
model = EvalModelTemplate()
checkpoint_callback = ModelCheckpoint(filepath=tmpdir, monitor="val_loss", save_top_k=-1)
checkpoint_callback = ModelCheckpoint(filepath=tmpdir, monitor="early_stop_on", save_top_k=-1)
trainer = Trainer(
default_root_dir=tmpdir,
early_stop_callback=False,
@ -330,7 +330,7 @@ def test_ckpt_metric_names(tmpdir):
progress_bar_refresh_rate=0,
limit_train_batches=0.01,
limit_val_batches=0.01,
-checkpoint_callback=ModelCheckpoint(monitor='val_loss', filepath=tmpdir + "/{val_loss:.2f}"),
+checkpoint_callback=ModelCheckpoint(monitor='early_stop_on', filepath=tmpdir + "/{val_loss:.2f}"),
)
trainer.fit(model)
@ -390,7 +390,7 @@ def test_ckpt_metric_names_results(tmpdir):
progress_bar_refresh_rate=0,
limit_train_batches=0.01,
limit_val_batches=0.01,
-checkpoint_callback=ModelCheckpoint(monitor='val_loss', filepath=tmpdir + "/{val_loss:.2f}"),
+checkpoint_callback=ModelCheckpoint(monitor='early_stop_on', filepath=tmpdir + "/{val_loss:.2f}"),
)
trainer.fit(model)
@ -413,7 +413,7 @@ def test_model_checkpoint_save_last_warning(tmpdir, caplog, max_epochs, should_v
model.validation_step = None
trainer = Trainer(
default_root_dir=tmpdir,
-checkpoint_callback=ModelCheckpoint(monitor='val_loss', filepath=tmpdir, save_top_k=0, save_last=save_last),
+checkpoint_callback=ModelCheckpoint(monitor='early_stop_on', filepath=tmpdir, save_top_k=0, save_last=save_last),
max_epochs=max_epochs,
)
trainer.fit(model)
@ -426,7 +426,7 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
model = EvalModelTemplate()
num_epochs = 3
model_checkpoint = ModelCheckpoint(
-monitor='val_loss', filepath=tmpdir, save_top_k=num_epochs, save_last=True
+monitor='early_stop_on', filepath=tmpdir, save_top_k=num_epochs, save_last=True
)
trainer = Trainer(
default_root_dir=tmpdir,

View File

@ -11,6 +11,7 @@ from tests.base.datamodules import TrialMNISTDataModule
from tests.base.develop_utils import reset_seed
from pytorch_lightning.utilities.model_utils import is_overridden
from pytorch_lightning.accelerators.gpu_backend import GPUBackend
+from pytorch_lightning.callbacks import ModelCheckpoint
def test_can_prepare_data(tmpdir):
@ -226,6 +227,7 @@ def test_dm_checkpoint_save(tmpdir):
default_root_dir=tmpdir,
max_epochs=3,
weights_summary=None,
+checkpoint_callback=ModelCheckpoint(monitor='early_stop_on')
)
# fit model

View File

@ -82,15 +82,21 @@ def test_loggers_fit_test(wandb, tmpdir, monkeypatch, logger_class):
log_metric_names = [(s, sorted(m.keys())) for s, m in logger.history]
if logger_class == TensorBoardLogger:
-    assert log_metric_names == [(0, ['hp_metric']),
-                                (0, ['epoch', 'val_acc', 'val_loss']),
-                                (0, ['epoch', 'train_some_val']),
-                                (0, ['hp_metric']),
-                                (1, ['epoch', 'test_acc', 'test_loss'])]
+    expected = [
+        (0, ['hp_metric']),
+        (0, ['epoch', 'train_some_val']),
+        (0, ['early_stop_on', 'epoch', 'val_acc']),
+        (0, ['hp_metric']),
+        (1, ['epoch', 'test_acc', 'test_loss'])
+    ]
+    assert log_metric_names == expected
else:
-    assert log_metric_names == [(0, ['epoch', 'val_acc', 'val_loss']),
-                                (0, ['epoch', 'train_some_val']),
-                                (1, ['epoch', 'test_acc', 'test_loss'])]
+    expected = [
+        (0, ['epoch', 'train_some_val']),
+        (0, ['early_stop_on', 'epoch', 'val_acc']),
+        (1, ['epoch', 'test_acc', 'test_loss'])
+    ]
+    assert log_metric_names == expected
@pytest.mark.parametrize("logger_class", [

View File

@ -10,6 +10,7 @@ from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import EvalModelTemplate
+from pytorch_lightning.utilities import APEX_AVAILABLE
@pytest.mark.skip(reason='dp + amp not supported currently') # TODO
@ -170,6 +171,7 @@ def test_amp_without_apex(tmpdir):
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
+@pytest.mark.skipif(not APEX_AVAILABLE, reason="test requires apex")
def test_amp_with_apex(tmpdir):
"""Check calling apex scaling in training."""
os.environ['PL_DEV_DEBUG'] = '1'

View File

@ -162,7 +162,7 @@ def test_load_model_from_checkpoint(tmpdir, model_template):
max_epochs=2,
limit_train_batches=0.4,
limit_val_batches=0.2,
-checkpoint_callback=ModelCheckpoint(tmpdir, monitor='val_loss', save_top_k=-1),
+checkpoint_callback=ModelCheckpoint(tmpdir, monitor='early_stop_on', save_top_k=-1),
default_root_dir=tmpdir,
)

View File

View File

@ -136,7 +136,7 @@ def test_validation_step_dict_return(tmpdir):
assert k in eval_results[1]
# ensure all the keys ended up as candidates for callbacks
-assert len(trainer.logger_connector.callback_metrics) in [9, 10]
+assert len(trainer.logger_connector.callback_metrics) in [7, 8]
# make sure correct steps were called
assert model.validation_step_called
@ -211,7 +211,7 @@ def test_val_step_step_end(tmpdir):
assert k in eval_results[1]
# ensure all the keys ended up as candidates for callbacks
-assert len(trainer.logger_connector.callback_metrics) in [10, 11]
+assert len(trainer.logger_connector.callback_metrics) in [8, 9]
# make sure correct steps were called
assert model.validation_step_called
@ -254,7 +254,7 @@ def test_no_val_step_end(tmpdir):
assert k in eval_results
# ensure all the keys ended up as candidates for callbacks
-assert len(trainer.logger_connector.callback_metrics) in [10, 11]
+assert len(trainer.logger_connector.callback_metrics) in [8, 9]
# make sure correct steps were called
assert model.validation_step_called
@ -297,7 +297,7 @@ def test_full_val_loop(tmpdir):
assert k in eval_results
# ensure all the keys ended up as candidates for callbacks
-assert len(trainer.logger_connector.callback_metrics) in [11, 12]
+assert len(trainer.logger_connector.callback_metrics) in [9, 10]
# make sure correct steps were called
assert model.validation_step_called

View File

View File

@ -2,7 +2,9 @@
Tests to ensure that the training loop works with a dict (1.0)
"""
from pytorch_lightning import Trainer
+from pytorch_lightning import callbacks
from tests.base.deterministic_model import DeterministicModel
+from tests.base import SimpleModule
import os
import torch
@ -141,5 +143,18 @@ def test__validation_step__step_end__epoch_end__log(tmpdir):
# we don't want to enable val metrics during steps because it is not something that users should do
callback_metrics = set(trainer.callback_metrics.keys())
callback_metrics.remove('debug_epoch')
expected_cb_metrics = {'a', 'b', 'c', 'd', 'e', 'epoch_b', 'epoch_d', 'epoch_f', 'f', 'g', 'step_b'}
assert expected_cb_metrics == callback_metrics
+def test_monitor_val_epoch_end(tmpdir):
+    epoch_min_loss_override = 0
+    model = SimpleModule()
+    checkpoint_callback = callbacks.ModelCheckpoint(save_top_k=1, monitor="avg_val_loss")
+    trainer = Trainer(
+        max_epochs=epoch_min_loss_override + 2,
+        logger=False,
+        checkpoint_callback=checkpoint_callback,
+    )
+    trainer.fit(model)

View File

@ -3,6 +3,7 @@ import torch
from pytorch_lightning import Trainer
from tests.base import EvalModelTemplate
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
def test_optimizer_with_scheduling(tmpdir):
@ -111,12 +112,36 @@ def test_multi_optimizer_with_scheduling_stepping(tmpdir):
'lr for optimizer 2 not adjusted correctly'
-def test_reduce_lr_on_plateau_scheduling(tmpdir):
+def test_reduce_lr_on_plateau_scheduling_missing_monitor(tmpdir):
hparams = EvalModelTemplate.get_default_hparams()
model = EvalModelTemplate(**hparams)
model.configure_optimizers = model.configure_optimizers__reduce_lr_on_plateau
# fit model
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
limit_val_batches=0.1,
limit_train_batches=0.2,
)
+    m = '.*ReduceLROnPlateau requires returning a dict from configure_optimizers.*'
+    with pytest.raises(MisconfigurationException, match=m):
+        trainer.fit(model)
+def test_reduce_lr_on_plateau_scheduling(tmpdir):
+    hparams = EvalModelTemplate.get_default_hparams()
+    class TestModel(EvalModelTemplate):
+        def configure_optimizers(self):
+            optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
+            lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
+            return {'optimizer': optimizer, 'lr_scheduler': lr_scheduler, 'monitor': 'early_stop_on'}
+    model = TestModel(**hparams)
+    # fit model
trainer = Trainer(
default_root_dir=tmpdir,
@ -128,7 +153,7 @@ def test_reduce_lr_on_plateau_scheduling(tmpdir):
assert results == 1
assert trainer.lr_schedulers[0] == \
-dict(scheduler=trainer.lr_schedulers[0]['scheduler'], monitor='val_loss',
+dict(scheduler=trainer.lr_schedulers[0]['scheduler'], monitor='early_stop_on',
interval='epoch', frequency=1, reduce_on_plateau=True), \
'lr schduler was not correctly converted to dict'
@ -167,7 +192,7 @@ def test_optimizer_return_options():
assert len(optim) == 1 and len(lr_sched) == 1 and len(freq) == 0
assert optim[0] == opt_a
assert lr_sched[0] == dict(scheduler=scheduler_a, interval='epoch',
-frequency=1, reduce_on_plateau=False, monitor='val_loss')
+frequency=1, reduce_on_plateau=False)
# opt single dictionary
model.configure_optimizers = lambda: {"optimizer": opt_a, "lr_scheduler": scheduler_a}
@ -175,7 +200,7 @@ def test_optimizer_return_options():
assert len(optim) == 1 and len(lr_sched) == 1 and len(freq) == 0
assert optim[0] == opt_a
assert lr_sched[0] == dict(scheduler=scheduler_a, interval='epoch',
-frequency=1, reduce_on_plateau=False, monitor='val_loss')
+frequency=1, reduce_on_plateau=False)
# opt multiple dictionaries with frequencies
model.configure_optimizers = lambda: (
@ -186,7 +211,7 @@ def test_optimizer_return_options():
assert len(optim) == 2 and len(lr_sched) == 2 and len(freq) == 2
assert optim[0] == opt_a
assert lr_sched[0] == dict(scheduler=scheduler_a, interval='epoch',
-frequency=1, reduce_on_plateau=False, monitor='val_loss')
+frequency=1, reduce_on_plateau=False)
assert freq == [1, 5]

View File

@ -432,7 +432,7 @@ def test_model_checkpoint_only_weights(tmpdir):
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
-checkpoint_callback=ModelCheckpoint(tmpdir, save_weights_only=True),
+checkpoint_callback=ModelCheckpoint(tmpdir, monitor='early_stop_on', save_weights_only=True),
)
# fit model
result = trainer.fit(model)
@ -508,7 +508,7 @@ def test_resume_from_checkpoint_epoch_restored(monkeypatch, tmpdir, tmpdir_serve
max_epochs=2,
limit_train_batches=0.65,
limit_val_batches=1,
-checkpoint_callback=ModelCheckpoint(tmpdir, monitor='val_loss', save_top_k=-1),
+checkpoint_callback=ModelCheckpoint(tmpdir, monitor='early_stop_on', save_top_k=-1),
default_root_dir=tmpdir,
early_stop_callback=False,
val_check_interval=1.,
@ -665,7 +665,7 @@ def test_test_checkpoint_path(tmpdir, ckpt_path, save_top_k):
max_epochs=2,
progress_bar_refresh_rate=0,
default_root_dir=tmpdir,
-checkpoint_callback=ModelCheckpoint(monitor='val_loss', save_top_k=save_top_k),
+checkpoint_callback=ModelCheckpoint(monitor='early_stop_on', save_top_k=save_top_k),
)
trainer.fit(model)
if ckpt_path == 'best':
@ -898,6 +898,7 @@ def test_gradient_clipping_fp16(tmpdir):
trainer.fit(model)
def test_gpu_choice(tmpdir):
trainer_options = dict(
default_root_dir=tmpdir,