tests for val step flow and logging (#3731)

* ref: test val epoch end

* ref: test val epoch end

* ref: test val epoch end

* ref: test log dict

* ref: test log dict

* ref: test log dict

* ref: test log dict
William Falcon 2020-09-29 22:12:56 -04:00 committed by GitHub
parent 3dcf7130c5
commit b3be8022bd
5 changed files with 263 additions and 74 deletions


@@ -32,6 +32,7 @@ from pytorch_lightning.core.step_result import EvalResult, TrainResult
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.utilities.parsing import (
AttributeDict,
@@ -216,7 +217,15 @@ class LightningModule(
if self._results is not None:
# in any epoch end can't log step metrics (only epoch metric)
if 'epoch_end' in self._current_fx_name and on_step:
-on_step = False
+m = f'on_step=True cannot be used on {self._current_fx_name} method'
+raise MisconfigurationException(m)
+if 'epoch_end' in self._current_fx_name and on_epoch == False:
+m = f'on_epoch cannot be False when called from the {self._current_fx_name} method'
+raise MisconfigurationException(m)
# add log_dict
# TODO: if logged twice fail with crash
# set the default depending on the fx_name
on_step = self.__auto_choose_log_on_step(on_step)
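The hunk above replaces a silent correction with a hard failure: logging with on_step=True (or on_epoch=False) from an *_epoch_end hook now raises. A minimal sketch of the new behavior, assuming an otherwise standard LightningModule (the model and metric name are illustrative, not part of this commit):

import torch
from pytorch_lightning import LightningModule

class SketchModel(LightningModule):
    def validation_epoch_end(self, outputs):
        # raises MisconfigurationException:
        # "on_step=True cannot be used on validation_epoch_end method"
        self.log('val_acc', torch.tensor(0.9), on_step=True)
        # would also raise: epoch-end hooks can only log epoch metrics
        self.log('val_acc', torch.tensor(0.9), on_epoch=False)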
@@ -238,6 +247,60 @@ class LightningModule(
sync_dist_group
)
def log_dict(
self,
dictionary: dict,
prog_bar: bool = False,
logger: bool = True,
on_step: Union[None, bool] = None,
on_epoch: Union[None, bool] = None,
reduce_fx: Callable = torch.mean,
tbptt_reduce_fx: Callable = torch.mean,
tbptt_pad_token: int = 0,
enable_graph: bool = False,
sync_dist: bool = False,
sync_dist_op: Union[Any, str] = 'mean',
sync_dist_group: Optional[Any] = None,
):
"""
Log a dictonary of values at once
Example::
values = {'loss': loss, 'acc': acc, ..., 'metric_n': metric_n}
self.log_dict(values)
Args:
dictionary: key value pairs (str, tensors)
prog_bar: if True logs to the progress base
logger: if True logs to the logger
on_step: if True logs at this step. None auto-logs for training_step but not validation/test_step
on_epoch: if True logs epoch accumulated metrics. None auto-logs for val/test step but not training_step
reduce_fx: Torch.mean by default
tbptt_reduce_fx: function to reduce on truncated back prop
tbptt_pad_token: token to use for padding
enable_graph: if True, will not auto detach the graph
sync_dist: if True, reduces the metric across GPUs/TPUs
sync_dist_op: the op to sync across
sync_dist_group: the ddp group:
"""
for k, v in dictionary.items():
self.log(
name=k,
value=v,
prog_bar=prog_bar,
logger=logger,
on_step=on_step,
on_epoch=on_epoch,
reduce_fx=reduce_fx,
enable_graph=enable_graph,
sync_dist=sync_dist,
sync_dist_group=sync_dist_group,
sync_dist_op=sync_dist_op,
tbptt_pad_token=tbptt_pad_token,
tbptt_reduce_fx=tbptt_reduce_fx,
)
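Since log_dict simply fans out to one self.log call per key, every option applies uniformly to all entries. A hypothetical equivalence, with loss and acc standing in for real tensors computed in a step method:

self.log_dict({'loss': loss, 'acc': acc}, on_step=True, on_epoch=True)
# ...behaves the same as:
self.log('loss', loss, on_step=True, on_epoch=True)
self.log('acc', acc, on_step=True, on_epoch=True)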
def __auto_choose_log_on_step(self, on_step):
if on_step is None:
if self._current_fx_name in {'training_step', 'training_step_end'}:


@@ -131,7 +131,10 @@ class Result(Dict):
# if user requests both step and epoch, then we split the metric in two automatically
# one will be logged per step. the other per epoch
was_forked = False
if on_step and on_epoch:
was_forked = True
# set step version
step_name = f'step_{name}'
self.__set_meta(
@@ -144,6 +147,7 @@ class Result(Dict):
reduce_fx=reduce_fx,
tbptt_reduce_fx=tbptt_reduce_fx,
tbptt_pad_token=tbptt_pad_token,
forked=False
)
self.__setitem__(step_name, value)
@@ -159,6 +163,7 @@ class Result(Dict):
reduce_fx=reduce_fx,
tbptt_reduce_fx=tbptt_reduce_fx,
tbptt_pad_token=tbptt_pad_token,
forked=False
)
self.__setitem__(epoch_name, value)
@@ -173,6 +178,7 @@ class Result(Dict):
reduce_fx,
tbptt_reduce_fx=tbptt_reduce_fx,
tbptt_pad_token=tbptt_pad_token,
forked=was_forked
)
# set the value
@@ -189,6 +195,7 @@ class Result(Dict):
reduce_fx: Callable,
tbptt_pad_token: int,
tbptt_reduce_fx: Callable,
forked: bool
):
# set the meta for the item
meta_value = value
@@ -201,6 +208,7 @@ class Result(Dict):
value=meta_value,
tbptt_reduce_fx=tbptt_reduce_fx,
tbptt_pad_token=tbptt_pad_token,
forked=forked
)
self['meta'][name] = meta
@@ -222,9 +230,10 @@ class Result(Dict):
return result
-def get_batch_log_metrics(self) -> dict:
+def get_batch_log_metrics(self, include_forked_originals=True) -> dict:
"""
Gets the metrics to log at the end of the batch step
"""
result = {}
@@ -232,6 +241,10 @@ class Result(Dict):
for k, options in meta.items():
if k == '_internal':
continue
if options['forked'] and not include_forked_originals:
continue
if options['logger'] and options['on_step']:
result[k] = self[k]
return result
@@ -264,7 +277,7 @@ class Result(Dict):
result[k] = self[k]
return result
-def get_batch_pbar_metrics(self):
+def get_batch_pbar_metrics(self, include_forked_originals=True):
"""
Gets the metrics to log at the end of the batch step
"""
@@ -274,6 +287,9 @@ class Result(Dict):
for k, options in meta.items():
if k == '_internal':
continue
if options['forked'] and not include_forked_originals:
continue
if options['prog_bar'] and options['on_step']:
result[k] = self[k]
return result
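Taken together, these Result changes mean a metric logged with both on_step=True and on_epoch=True is forked into step_/epoch_ copies, while the original key is kept but tagged forked=True in its meta. A rough sketch, assuming Result.log routes through the __set_meta path shown above:

import torch
from pytorch_lightning.core.step_result import Result

result = Result()
result.log('acc', torch.tensor(0.9), on_step=True, on_epoch=True)
# 'step_acc' and 'epoch_acc' are stored with forked=False; the original
# 'acc' is kept but its meta carries forked=True
result.get_batch_log_metrics()                                # includes 'acc' and 'step_acc'
result.get_batch_log_metrics(include_forked_originals=False)  # only 'step_acc'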


@@ -202,6 +202,7 @@ class EvaluationLoop(object):
if self.testing:
if is_overridden('test_epoch_end', model=model):
model._current_fx_name = 'test_epoch_end'
if using_eval_result:
eval_results = self.__gather_epoch_end_eval_results(outputs)
@@ -210,6 +211,7 @@ class EvaluationLoop(object):
else:
if is_overridden('validation_epoch_end', model=model):
model._current_fx_name = 'validation_epoch_end'
if using_eval_result:
eval_results = self.__gather_epoch_end_eval_results(outputs)
@@ -314,8 +316,8 @@ class EvaluationLoop(object):
self.__log_result_step_metrics(output, batch_idx)
def __log_result_step_metrics(self, output, batch_idx):
-step_log_metrics = output.batch_log_metrics
-step_pbar_metrics = output.batch_pbar_metrics
+step_log_metrics = output.get_batch_log_metrics(include_forked_originals=False)
+step_pbar_metrics = output.get_batch_pbar_metrics(include_forked_originals=False)
if len(step_log_metrics) > 0:
# make the metrics appear as a different line in the same graph
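The switch from the batch_log_metrics/batch_pbar_metrics properties to the explicit getters is what keeps doubly-logged validation metrics from being emitted twice per step: the forked original would duplicate its step_ copy. A sketch, assuming a validation_step that called self.log('d', acc, on_step=True, on_epoch=True):

output.get_batch_log_metrics()                                # {'d': ..., 'step_d': ...}
output.get_batch_log_metrics(include_forked_originals=False)  # {'step_d': ...}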


@@ -52,8 +52,6 @@ def test__validation_step__log(tmpdir):
'b',
'step_b/epoch_0',
'step_b/epoch_1',
-'b/epoch_0',
-'b/epoch_1',
'epoch_b',
'epoch',
}
@@ -67,7 +65,7 @@ def test__validation_step__log(tmpdir):
assert expected_cb_metrics == callback_metrics
-def test__validation_step__epoch_end__log(tmpdir):
+def test__validation_step__step_end__epoch_end__log(tmpdir):
"""
Tests that validation_step, validation_step_end, and validation_epoch_end can all log
"""
@@ -88,16 +86,22 @@ def test__validation_step__epoch_end__log(tmpdir):
self.log('c', acc)
self.log('d', acc, on_step=True, on_epoch=True)
self.validation_step_called = True
return acc
def validation_step_end(self, acc):
self.validation_step_end_called = True
self.log('e', acc)
self.log('f', acc, on_step=True, on_epoch=True)
return ['random_thing']
def validation_epoch_end(self, outputs):
-self.log('e', torch.tensor(2, device=self.device), on_step=True, on_epoch=True)
+self.log('g', torch.tensor(2, device=self.device), on_epoch=True)
self.validation_epoch_end_called = True
def backward(self, trainer, loss, optimizer, optimizer_idx):
loss.backward()
model = TestModel()
-model.validation_step_end = None
trainer = Trainer(
default_root_dir=tmpdir,
@@ -110,38 +114,32 @@ def test__validation_step__epoch_end__log(tmpdir):
trainer.fit(model)
# make sure all the metrics are available for callbacks
-logged_metrics = set(trainer.logged_metrics.keys())
expected_logged_metrics = {
'epoch',
'a',
'b',
'step_b',
'epoch_b',
'c',
'd',
-'d/epoch_0',
-'d/epoch_1',
+'step_d/epoch_0',
+'step_d/epoch_1',
'epoch_d',
'e',
-'epoch_e',
-'epoch',
'f',
'step_f/epoch_0',
'step_f/epoch_1',
'epoch_f',
'g',
}
logged_metrics = set(trainer.logged_metrics.keys())
assert expected_logged_metrics == logged_metrics
-# we don't want to enable val metrics during steps because it is not something that users should do
-expected_cb_metrics = {
-'a',
-'b',
-'step_b',
-'epoch_b',
-'c',
-'d',
-'epoch_d',
-'e',
-'epoch_e',
-}
progress_bar_metrics = set(trainer.progress_bar_metrics.keys())
expected_pbar_metrics = set()
assert expected_pbar_metrics == progress_bar_metrics
# we don't want to enable val metrics during steps because it is not something that users should do
callback_metrics = set(trainer.callback_metrics.keys())
expected_cb_metrics = {'a', 'b', 'c', 'd', 'e', 'epoch_b', 'epoch_d', 'epoch_f', 'f', 'g', 'step_b'}
assert expected_cb_metrics == callback_metrics


@@ -17,12 +17,35 @@ def test__training_step__log(tmpdir):
def training_step(self, batch, batch_idx):
acc = self.step(batch, batch_idx)
acc = acc + batch_idx
-self.log('step_acc', acc, on_step=True, on_epoch=False)
-self.log('epoch_acc', acc, on_step=False, on_epoch=True)
-self.log('no_prefix_step_epoch_acc', acc, on_step=True, on_epoch=True)
-self.log('pbar_step_acc', acc, on_step=True, prog_bar=True, on_epoch=False, logger=False)
-self.log('pbar_epoch_acc', acc, on_step=False, on_epoch=True, prog_bar=True, logger=False)
-self.log('pbar_step_epoch_acc', acc, on_step=True, on_epoch=True, prog_bar=True, logger=False)
# -----------
# default
# -----------
self.log('default', acc)
# -----------
# logger
# -----------
# on_step T on_epoch F
self.log('l_s', acc, on_step=True, on_epoch=False, prog_bar=False, logger=True)
# on_step F on_epoch T
self.log('l_e', acc, on_step=False, on_epoch=True, prog_bar=False, logger=True)
# on_step T on_epoch T
self.log('l_se', acc, on_step=True, on_epoch=True, prog_bar=False, logger=True)
# -----------
# pbar
# -----------
# on_step T on_epoch F
self.log('p_s', acc, on_step=True, on_epoch=False, prog_bar=True, logger=False)
# on_step F on_epoch T
self.log('p_e', acc, on_step=False, on_epoch=True, prog_bar=True, logger=False)
# on_step T on_epoch T
self.log('p_se', acc, on_step=True, on_epoch=True, prog_bar=True, logger=False)
self.training_step_called = True
return acc
@@ -46,19 +69,38 @@ def test__training_step__log(tmpdir):
# make sure correct steps were called
assert model.training_step_called
assert not model.training_step_end_called
assert not model.training_epoch_end_called
# make sure all the metrics are available for callbacks
-metrics = [
-'step_acc',
-'epoch_acc',
-'no_prefix_step_epoch_acc', 'step_no_prefix_step_epoch_acc', 'epoch_no_prefix_step_epoch_acc',
-'pbar_step_acc',
-'pbar_epoch_acc',
-'pbar_step_epoch_acc', 'step_pbar_step_epoch_acc', 'epoch_pbar_step_epoch_acc',
-]
-expected_metrics = set(metrics + ['debug_epoch'])
logged_metrics = set(trainer.logged_metrics.keys())
expected_logged_metrics = {
'epoch',
'default',
'l_e',
'l_s',
'l_se',
'step_l_se',
'epoch_l_se',
}
assert logged_metrics == expected_logged_metrics
pbar_metrics = set(trainer.progress_bar_metrics.keys())
expected_pbar_metrics = {
'p_e',
'p_s',
'p_se',
'step_p_se',
'epoch_p_se',
}
assert pbar_metrics == expected_pbar_metrics
callback_metrics = set(trainer.callback_metrics.keys())
-assert expected_metrics == callback_metrics
callback_metrics.remove('debug_epoch')
expected_callback_metrics = set()
expected_callback_metrics = expected_callback_metrics.union(logged_metrics)
expected_callback_metrics = expected_callback_metrics.union(pbar_metrics)
expected_callback_metrics.remove('epoch')
assert callback_metrics == expected_callback_metrics
def test__training_step__epoch_end__log(tmpdir):
@@ -69,22 +111,17 @@ def test__training_step__epoch_end__log(tmpdir):
class TestModel(DeterministicModel):
def training_step(self, batch, batch_idx):
self.training_step_called = True
acc = self.step(batch, batch_idx)
acc = acc + batch_idx
-self.log('step_acc', acc, on_step=True, on_epoch=False)
-self.log('epoch_acc', acc, on_step=False, on_epoch=True)
-self.log('no_prefix_step_epoch_acc', acc, on_step=True, on_epoch=True)
-self.log('pbar_step_acc', acc, on_step=True, prog_bar=True, on_epoch=False, logger=False)
-self.log('pbar_epoch_acc', acc, on_step=False, on_epoch=True, prog_bar=True, logger=False)
-self.log('pbar_step_epoch_acc', acc, on_step=True, on_epoch=True, prog_bar=True, logger=False)
-self.training_step_called = True
self.log('a', acc, on_step=True, on_epoch=True)
self.log_dict({'a1': acc, 'a2': acc})
return acc
def training_epoch_end(self, outputs):
self.training_epoch_end_called = True
-# logging is independent of epoch_end loops
-self.log('custom_epoch_end_metric', torch.tensor(37.2))
self.log('b1', outputs[0]['loss'])
self.log('b', outputs[0]['loss'], on_epoch=True, prog_bar=True, logger=True)
def backward(self, trainer, loss, optimizer, optimizer_idx):
loss.backward()
@@ -108,27 +145,100 @@ def test__training_step__epoch_end__log(tmpdir):
assert model.training_epoch_end_called
# make sure all the metrics are available for callbacks
-metrics = [
-'step_acc',
-'epoch_acc',
-'no_prefix_step_epoch_acc', 'step_no_prefix_step_epoch_acc', 'epoch_no_prefix_step_epoch_acc',
-'pbar_step_acc',
-'pbar_epoch_acc',
-'pbar_step_epoch_acc', 'step_pbar_step_epoch_acc', 'epoch_pbar_step_epoch_acc',
-'custom_epoch_end_metric'
-]
-expected_metrics = set(metrics + ['debug_epoch'])
logged_metrics = set(trainer.logged_metrics.keys())
expected_logged_metrics = {
'epoch',
'a',
'step_a',
'epoch_a',
'b',
'b1',
'a1',
'a2'
}
assert logged_metrics == expected_logged_metrics
pbar_metrics = set(trainer.progress_bar_metrics.keys())
expected_pbar_metrics = {
'b',
}
assert pbar_metrics == expected_pbar_metrics
callback_metrics = set(trainer.callback_metrics.keys())
-assert expected_metrics == callback_metrics
callback_metrics.remove('debug_epoch')
expected_callback_metrics = set()
expected_callback_metrics = expected_callback_metrics.union(logged_metrics)
expected_callback_metrics = expected_callback_metrics.union(pbar_metrics)
expected_callback_metrics.remove('epoch')
assert callback_metrics == expected_callback_metrics
# verify global steps were correctly called
# epoch 0
assert trainer.dev_debugger.logged_metrics[0]['global_step'] == 0
assert trainer.dev_debugger.logged_metrics[1]['global_step'] == 1
assert trainer.dev_debugger.logged_metrics[2]['global_step'] == 1
# epoch 1
assert trainer.dev_debugger.logged_metrics[3]['global_step'] == 2
assert trainer.dev_debugger.logged_metrics[4]['global_step'] == 3
assert trainer.dev_debugger.logged_metrics[5]['global_step'] == 3
def test__training_step__step_end__epoch_end__log(tmpdir):
"""
Tests that training_step, training_step_end, and training_epoch_end can all log
"""
os.environ['PL_DEV_DEBUG'] = '1'
class TestModel(DeterministicModel):
def training_step(self, batch, batch_idx):
self.training_step_called = True
acc = self.step(batch, batch_idx)
acc = acc + batch_idx
self.log('a', acc, on_step=True, on_epoch=True)
return acc
def training_step_end(self, out):
self.training_step_end_called = True
self.log('b', out, on_step=True, on_epoch=True, prog_bar=True, logger=True)
return out
def training_epoch_end(self, outputs):
self.training_epoch_end_called = True
self.log('c', outputs[0]['loss'], on_epoch=True, prog_bar=True, logger=True)
def backward(self, trainer, loss, optimizer, optimizer_idx):
loss.backward()
model = TestModel()
model.val_dataloader = None
trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=2,
limit_val_batches=2,
max_epochs=2,
row_log_interval=1,
weights_summary=None,
)
trainer.fit(model)
# make sure correct steps were called
assert model.training_step_called
assert model.training_step_end_called
assert model.training_epoch_end_called
# make sure all the metrics are available for callbacks
logged_metrics = set(trainer.logged_metrics.keys())
expected_logged_metrics = {
'a',
'step_a',
'epoch_a',
'b',
'step_b',
'epoch_b',
'c',
'epoch',
}
assert logged_metrics == expected_logged_metrics
pbar_metrics = set(trainer.progress_bar_metrics.keys())
expected_pbar_metrics = {'b', 'c', 'epoch_b', 'step_b'}
assert pbar_metrics == expected_pbar_metrics
callback_metrics = set(trainer.callback_metrics.keys())
callback_metrics.remove('debug_epoch')
expected_callback_metrics = set()
expected_callback_metrics = expected_callback_metrics.union(logged_metrics)
expected_callback_metrics = expected_callback_metrics.union(pbar_metrics)
expected_callback_metrics.remove('epoch')
assert callback_metrics == expected_callback_metrics