ref: enable self.log for eval loop metrics (#3715)

* ref: test val epoch end
William Falcon 2020-09-29 02:00:28 -04:00 committed by GitHub
parent c41ea86b35
commit f42ea303c9
6 changed files with 227 additions and 40 deletions


@@ -155,8 +155,8 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
        value: Any,
        prog_bar: bool = False,
        logger: bool = True,
-       on_step: bool = True,
-       on_epoch: bool = True,
+       on_step: Union[None, bool] = None,
+       on_epoch: Union[None, bool] = None,
        reduce_fx: Callable = torch.mean,
        tbptt_reduce_fx: Callable = torch.mean,
        tbptt_pad_token: int = 0,
@@ -190,8 +190,8 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
            value: value name
            prog_bar: if True logs to the progress bar
            logger: if True logs to the logger
-           on_step: if True logs the output of validation_step or test_step
-           on_epoch: if True, logs the output of the training loop aggregated
+           on_step: if True logs at this step. None auto-logs for training_step but not validation/test_step
+           on_epoch: if True logs epoch accumulated metrics. None auto-logs for val/test step but not training_step
            reduce_fx: Torch.mean by default
            tbptt_reduce_fx: function to reduce on truncated back prop
            tbptt_pad_token: token to use for padding
@@ -205,6 +205,10 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
        if 'epoch_end' in self._current_fx_name and on_step:
            on_step = False

+       # set the default depending on the fx_name
+       on_step = self.__auto_choose_log_on_step(on_step)
+       on_epoch = self.__auto_choose_log_on_epoch(on_epoch)
+
        self._results.log(
            name,
            value,
@@ -221,6 +225,30 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
            sync_dist_group
        )

+    def __auto_choose_log_on_step(self, on_step):
+        if on_step is None:
+            if self._current_fx_name in {'training_step', 'training_step_end'}:
+                on_step = True
+            elif self._current_fx_name in {'evaluation_step', 'evaluation_step_end',
+                                           'evaluation_epoch_end', 'training_epoch_end'}:
+                on_step = False
+            else:
+                on_step = False
+        return on_step
+
+    def __auto_choose_log_on_epoch(self, on_epoch):
+        if on_epoch is None:
+            if self._current_fx_name in {'training_step', 'training_step_end'}:
+                on_epoch = False
+            elif self._current_fx_name in {'evaluation_step', 'evaluation_step_end',
+                                           'evaluation_epoch_end', 'training_epoch_end'}:
+                on_epoch = True
+            else:
+                on_epoch = True
+        return on_epoch
+
    def forward(self, *args, **kwargs):
        r"""
        Same as :meth:`torch.nn.Module.forward()`, however in Lightning you want this to define
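
For reference, the effective defaults introduced by these two helpers can be summarized in a small standalone sketch; `resolve_log_defaults` is a hypothetical name used only for illustration, not part of the Lightning API.

# Standalone sketch of the default resolution performed by the private helpers above.
def resolve_log_defaults(fx_name, on_step=None, on_epoch=None):
    training_hooks = {'training_step', 'training_step_end'}
    if on_step is None:
        # only training steps log per-step values by default
        on_step = fx_name in training_hooks
    if on_epoch is None:
        # val/test steps and *_epoch_end hooks log epoch aggregates by default
        on_epoch = fx_name not in training_hooks
    return on_step, on_epoch

assert resolve_log_defaults('training_step') == (True, False)
assert resolve_log_defaults('validation_step') == (False, True)
assert resolve_log_defaults('training_epoch_end') == (False, True)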


@@ -352,9 +352,14 @@ class Result(Dict):
    @classmethod
    def reduce_on_epoch_end(cls, outputs):
        # get the batch sizes for all outputs
-       batch_sizes = torch.stack([x.get_batch_sizes() for x in outputs]).view(-1)
+       batch_sizes = []
+       meta = {}
+       for x in outputs:
+           batch_sizes.append(x.get_batch_sizes())
+           meta.update(x['meta'])
+
+       batch_sizes = torch.stack(batch_sizes).view(-1)
-       meta = outputs[0]['meta']
        result = cls()
        result = recursive_gather(outputs, result)
        recursive_stack(result)
@@ -371,6 +376,8 @@ class Result(Dict):
                reduced_val = fx(result[k])
                result[k] = reduced_val
            else:
                del result[k]

+       result['meta'] = meta
        return result
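
The epoch reduction now merges 'meta' across all step outputs instead of keeping only the first one; the merge itself is a plain dict update, sketched here with dicts standing in for Result objects (the per-key meta values are illustrative).

# Sketch of the meta merge across step outputs; previously only outputs[0]['meta']
# was kept, which dropped any key that was first logged in a later step.
outputs = [
    {'meta': {'val_acc': {'on_epoch': True}}},
    {'meta': {'val_loss': {'on_epoch': True}}},
]
meta = {}
for x in outputs:
    meta.update(x['meta'])

assert set(meta) == {'val_acc', 'val_loss'}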
@@ -871,7 +878,7 @@ class EvalResult(Result):
def weighted_mean(result, weights):
-    weights = weights.to(result.device)
+    weights = weights.to(result.device)[:result.size(0)]
    numerator = torch.dot(result.float(), weights.transpose(-1, 0).float())
    result = numerator / weights.sum().float()
    return result
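
The added slice `[:result.size(0)]` presumably keeps the dot product valid when a metric was logged on fewer batches than the loop ran. A tiny numeric sketch of the batch-size-weighted mean (a standalone re-implementation, not the library function itself):

import torch

# Standalone sketch of the batch-size-weighted mean used for epoch reduction.
def weighted_mean(result, weights):
    # clip weights to the number of values actually logged for this key
    weights = weights.to(result.device)[:result.size(0)].float()
    return torch.dot(result.float(), weights) / weights.sum()

values = torch.tensor([1.0, 3.0])        # metric logged on 2 batches
batch_sizes = torch.tensor([4, 2, 2])    # 3 batches ran, so weights get clipped
print(weighted_mean(values, batch_sizes))  # tensor(1.6667): (1*4 + 3*2) / (4 + 2)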


@@ -20,6 +20,7 @@ from pytorch_lightning.core.step_result import EvalResult, Result
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pprint import pprint
from typing import Iterable
+from copy import deepcopy


class LoggerConnector:
@@ -29,6 +30,7 @@ class LoggerConnector:
        self.callback_metrics = {}
        self.logged_metrics = {}
        self.progress_bar_metrics = {}
+       self.eval_loop_results = []

    def on_trainer_init(self, logger, log_save_interval, row_log_interval):
        # logging
@@ -100,11 +102,69 @@ class LoggerConnector:
        self.trainer.dev_debugger.track_pbar_metrics_history(metrics)

    def on_evaluation_epoch_end(self, eval_results, using_eval_result, test_mode):
-       # TODO: merge both functions?
-       self._log_on_evaluation_epoch_end_metrics(eval_results, using_eval_result)
-       return self.__log_evaluation_epoch_metrics_2(eval_results, test_mode)
+       self._track_callback_metrics(eval_results, using_eval_result)
+       self._log_on_evaluation_epoch_end_metrics()

-   def _log_on_evaluation_epoch_end_metrics(self, eval_results, using_eval_result):
+       # TODO: deprecate parts of this for 1.0 (when removing results)
+       self.__process_eval_epoch_end_results_and_log_legacy(eval_results, test_mode)
+
+       # get the final loop results
+       eval_loop_results = self._get_evaluate_epoch_results(test_mode)
+       return eval_loop_results
+
+   def _get_evaluate_epoch_results(self, test_mode):
+       # log results of test
+       if test_mode and self.trainer.is_global_zero and self.trainer.verbose_test:
+           print('-' * 80)
+           for result_idx, results in enumerate(self.eval_loop_results):
+               print(f'DATALOADER:{result_idx} TEST RESULTS')
+               pprint(results)
+               print('-' * 80)
+
+       results = self.eval_loop_results
+
+       # clear mem
+       self.eval_loop_results = []
+       return results
+
+   def _log_on_evaluation_epoch_end_metrics(self):
+       step_metrics = self.trainer.evaluation_loop.step_metrics
+
+       # clear mem
+       self.trainer.evaluation_loop.step_metrics = []
+
+       num_loaders = len(step_metrics)
+
+       # process metrics per dataloader
+       for dl_idx, dl_metrics in enumerate(step_metrics):
+           if len(dl_metrics) == 0:
+               continue
+
+           reduced_epoch_metrics = dl_metrics[0].__class__.reduce_on_epoch_end(dl_metrics)
+
+           # make the keys 'k/dl'
+           reduced_epoch_metrics = self.__rename_keys_by_dataloader_idx(reduced_epoch_metrics, dl_idx, num_loaders)
+
+           # track the metrics
+           logger_metrics = reduced_epoch_metrics.get_epoch_log_metrics()
+           pbar_metrics = reduced_epoch_metrics.get_epoch_pbar_metrics()
+           self.logged_metrics.update(logger_metrics)
+           self.progress_bar_metrics.update(pbar_metrics)
+
+           # enable the metrics to be monitored
+           self.callback_metrics.update(logger_metrics)
+           self.callback_metrics.update(pbar_metrics)
+
+           # track the final results for the dataloader
+           self.eval_loop_results.append(deepcopy(self.callback_metrics))
+
+   def __rename_keys_by_dataloader_idx(self, metrics, dataloader_idx, num_loaders):
+       if num_loaders == 1:
+           return metrics
+
+       result = {f'{k}/dataloader_idx_{dataloader_idx}': v for k, v in metrics.items()}
+       return result
+
+   def _track_callback_metrics(self, eval_results, using_eval_result):
        if len(eval_results) > 0 and eval_results[0] is None:
            return
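
The per-dataloader snapshot appended to eval_loop_results is deep-copied because callback_metrics keeps being updated for later dataloaders; a plain reference would alias the same dict. A short sketch of that aliasing concern, using plain dicts:

from copy import deepcopy

# Sketch of why the per-dataloader snapshot is copied before being stored.
callback_metrics = {'val_acc/dataloader_idx_0': 0.8}
eval_loop_results = [deepcopy(callback_metrics)]

callback_metrics['val_acc/dataloader_idx_1'] = 0.6   # updated for the next loader
assert 'val_acc/dataloader_idx_1' not in eval_loop_results[0]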
@@ -141,11 +201,10 @@ class LoggerConnector:
                flat['early_stop_on'] = flat['val_loss']
            self.trainer.logger_connector.callback_metrics.update(flat)

-   def __log_evaluation_epoch_metrics_2(self, eval_results, test_mode):
+   def __process_eval_epoch_end_results_and_log_legacy(self, eval_results, test_mode):
        if self.trainer.running_sanity_check:
            return

-       eval_loop_results = []
        if eval_results is not None and len(eval_results) > 0:
            # in eval, the user may return something at every validation step without final reduction
@@ -179,17 +238,7 @@ class LoggerConnector:
                self.trainer.logger_connector.callback_metrics.update(prog_bar_metrics)

                if len(dataloader_result_metrics) > 0:
-                   eval_loop_results.append(dataloader_result_metrics)
-
-       # log results of test
-       if test_mode and self.trainer.is_global_zero and self.trainer.verbose_test:
-           print('-' * 80)
-           for result_idx, results in enumerate(eval_loop_results):
-               print(f'DATALOADER:{result_idx} TEST RESULTS')
-               pprint(results)
-               print('-' * 80)
-
-       return eval_loop_results
+                   self.eval_loop_results.append(dataloader_result_metrics)

    def on_train_epoch_end(self, epoch_output):
        pass
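
When more than one val/test dataloader is used, keys get a per-dataloader suffix; the behaviour of __rename_keys_by_dataloader_idx can be reproduced standalone with a plain dict (sketch only, not the connector method itself):

# Standalone sketch of the per-dataloader key suffixing.
def rename_keys_by_dataloader_idx(metrics, dataloader_idx, num_loaders):
    if num_loaders == 1:
        return metrics  # single loader: keys stay unchanged
    return {f'{k}/dataloader_idx_{dataloader_idx}': v for k, v in metrics.items()}

print(rename_keys_by_dataloader_idx({'val_acc': 0.91}, 1, 2))
# {'val_acc/dataloader_idx_1': 0.91}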


@@ -23,6 +23,7 @@ class EvaluationLoop(object):
        self.trainer = trainer
        self.testing = False
        self.outputs = []
+       self.step_metrics = []
        self.predictions = None
        self.max_batches = None
@@ -278,14 +279,21 @@ class EvaluationLoop(object):
        else:
            self.trainer.call_hook('on_validation_epoch_end', *args, **kwargs)

-   def log_evaluation_step_metrics(self, output, batch_idx):
+   def log_evaluation_step_metrics(self, batch, batch_idx):
+       results = self.trainer.get_model()._results
+       if len(results) == 1:
+           return None
+
+       results.track_batch_size(len(batch))
+       self.__log_result_step_metrics(results, batch_idx)
+       return results
+
+   # TODO: deprecate at 1.0
+   def log_evaluation_step_metrics_legacy(self, output, batch_idx):
        if self.trainer.running_sanity_check:
            return

-       results = self.trainer.get_model()._results
-       self.__log_result_step_metrics(results, batch_idx)
-
-       # TODO: deprecate at 1.0
        if isinstance(output, EvalResult):
            self.__log_result_step_metrics(output, batch_idx)
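
The early return in the new method relies on a fresh Result already holding an internal 'meta' entry, so a length of 1 means the user never called self.log in that step; that assumption about Result internals is made explicit in this sketch, which uses plain dicts:

# Sketch of the "did the user log anything this step?" check.
def user_logged_something(results):
    return len(results) > 1  # more than just the internal 'meta' entry

print(user_logged_something({'meta': {}}))                  # False -> return None
print(user_logged_something({'meta': {}, 'val_acc': 0.9}))  # True  -> track batch size and log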


@@ -413,6 +413,7 @@ class Trainer(
        for dataloader_idx, dataloader in enumerate(dataloaders):
            # bookkeeping
            dl_outputs = []
+           dl_step_metrics = []
            dataloader = self.accelerator_backend.process_dataloader(dataloader)
            dl_max_batches = self.evaluation_loop.max_batches[dataloader_idx]
@@ -436,13 +437,22 @@ class Trainer(
                # clean up
                self.evaluation_loop.evaluation_batch_end_cleanup(output, batch_idx, dataloader_idx)
-               self.evaluation_loop.log_evaluation_step_metrics(output, batch_idx)

+               # track epoch level metrics
+               # TODO: deprecate 1.0
+               self.evaluation_loop.log_evaluation_step_metrics_legacy(output, batch_idx)
+
+               # log step metrics
+               step_metrics = self.evaluation_loop.log_evaluation_step_metrics(batch, batch_idx)
+
+               if step_metrics is not None:
+                   dl_step_metrics.append(step_metrics)

                # track epoch level outputs
                if output is not None:
                    dl_outputs.append(output)

            self.evaluation_loop.outputs.append(dl_outputs)
+           self.evaluation_loop.step_metrics.append(dl_step_metrics)

        # lightning module method
        eval_results = self.evaluation_loop.evaluation_epoch_end(num_dataloaders=len(dataloaders))
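
The bookkeeping above builds a nested structure that the logger connector later reduces per dataloader; a sketch of its shape, with plain dicts standing in for Result objects:

# Sketch of the nesting: one inner list per dataloader, one logged Result
# (plain dicts here) per batch that actually called self.log.
step_metrics = []  # what ends up in evaluation_loop.step_metrics
for dl_batches in ([{'val_acc': 0.8}, {'val_acc': 0.9}], [{'val_acc': 0.7}]):
    dl_step_metrics = [batch_results for batch_results in dl_batches]
    step_metrics.append(dl_step_metrics)

assert len(step_metrics) == 2      # two dataloaders
assert len(step_metrics[0]) == 2   # dataloader 0 logged on two batches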


@@ -17,14 +17,14 @@ def test__validation_step__log(tmpdir):
        def training_step(self, batch, batch_idx):
            acc = self.step(batch, batch_idx)
            acc = acc + batch_idx
-           self.log('train_step_acc', acc, on_step=True, on_epoch=True)
+           self.log('a', acc, on_step=True, on_epoch=True)
            self.training_step_called = True
            return acc

        def validation_step(self, batch, batch_idx):
            acc = self.step(batch, batch_idx)
            acc = acc + batch_idx
-           self.log('val_step_acc', acc, on_step=True, on_epoch=True)
+           self.log('b', acc, on_step=True, on_epoch=True)
            self.training_step_called = True

        def backward(self, trainer, loss, optimizer, optimizer_idx):
@@ -46,18 +46,103 @@ def test__validation_step__log(tmpdir):
    # make sure all the metrics are available for callbacks
    expected_logged_metrics = {
+       'a',
+       'step_a',
+       'epoch_a',
+       'b',
+       'step_b/epoch_0',
+       'step_b/epoch_1',
+       'b/epoch_0',
+       'b/epoch_1',
+       'epoch_b',
        'epoch',
-       'train_step_acc', 'step_train_step_acc', 'epoch_train_step_acc',
-       'val_step_acc/epoch_0', 'val_step_acc/epoch_1',
-       'step_val_step_acc/epoch_0', 'step_val_step_acc/epoch_1',
    }
    logged_metrics = set(trainer.logged_metrics.keys())
    assert expected_logged_metrics == logged_metrics

    # we don't want to enable val metrics during steps because it is not something that users should do
-   expected_cb_metrics = [
-       'train_step_acc', 'step_train_step_acc', 'epoch_train_step_acc',
-   ]
-   expected_cb_metrics = set(expected_cb_metrics)
+   # on purpose DO NOT allow step_b... it's silly to monitor val step metrics
+   expected_cb_metrics = {'a', 'b', 'epoch_a', 'epoch_b', 'step_a'}

    callback_metrics = set(trainer.callback_metrics.keys())
    assert expected_cb_metrics == callback_metrics


+def test__validation_step__epoch_end__log(tmpdir):
+    """
+    Tests that validation_step can log
+    """
+    os.environ['PL_DEV_DEBUG'] = '1'
+
+    class TestModel(DeterministicModel):
+        def training_step(self, batch, batch_idx):
+            acc = self.step(batch, batch_idx)
+            acc = acc + batch_idx
+            self.log('a', acc)
+            self.log('b', acc, on_step=True, on_epoch=True)
+            self.training_step_called = True
+            return acc
+
+        def validation_step(self, batch, batch_idx):
+            acc = self.step(batch, batch_idx)
+            acc = acc + batch_idx
+            self.log('c', acc)
+            self.log('d', acc, on_step=True, on_epoch=True)
+            self.validation_step_called = True
+
+        def validation_epoch_end(self, outputs):
+            self.log('e', torch.tensor(2, device=self.device), on_step=True, on_epoch=True)
+            self.validation_epoch_end_called = True
+
+        def backward(self, trainer, loss, optimizer, optimizer_idx):
+            loss.backward()
+
+    model = TestModel()
+    model.validation_step_end = None
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        limit_train_batches=2,
+        limit_val_batches=2,
+        max_epochs=2,
+        row_log_interval=1,
+        weights_summary=None,
+    )
+    trainer.fit(model)
+
+    # make sure all the metrics are available for callbacks
+    expected_logged_metrics = {
+        'a',
+        'b',
+        'step_b',
+        'epoch_b',
+        'c',
+        'd',
+        'd/epoch_0',
+        'd/epoch_1',
+        'step_d/epoch_0',
+        'step_d/epoch_1',
+        'epoch_d',
+        'e',
+        'epoch_e',
+        'epoch',
+    }
+    logged_metrics = set(trainer.logged_metrics.keys())
+    assert expected_logged_metrics == logged_metrics
+
+    # we don't want to enable val metrics during steps because it is not something that users should do
+    expected_cb_metrics = {
+        'a',
+        'b',
+        'step_b',
+        'epoch_b',
+        'c',
+        'd',
+        'epoch_d',
+        'e',
+        'epoch_e',
+        'debug_epoch',
+    }
+    callback_metrics = set(trainer.callback_metrics.keys())
+    assert expected_cb_metrics == callback_metrics
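
For context, a hedged usage sketch of the behaviour these tests pin down, written as LightningModule hook bodies only; `compute_loss` is a hypothetical helper and the key names are illustrative.

def training_step(self, batch, batch_idx):
    loss = self.compute_loss(batch)
    # defaults here resolve to on_step=True, on_epoch=False
    self.log('train_loss', loss)
    return loss

def validation_step(self, batch, batch_idx):
    loss = self.compute_loss(batch)
    # defaults here resolve to on_step=False, on_epoch=True (epoch-aggregated)
    self.log('val_loss', loss)
    # opting into both produces step_/epoch_ prefixed and /epoch_N suffixed keys
    # like those asserted in the tests above
    self.log('val_loss_verbose', loss, on_step=True, on_epoch=True)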