ref: enable self.log for eval loop metrics (#3715)
* ref: test val epoch end
This commit is contained in:
parent
c41ea86b35
commit
f42ea303c9
@@ -155,8 +155,8 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
value: Any,
prog_bar: bool = False,
logger: bool = True,
- on_step: bool = True,
- on_epoch: bool = True,
+ on_step: Union[None, bool] = None,
+ on_epoch: Union[None, bool] = None,
reduce_fx: Callable = torch.mean,
tbptt_reduce_fx: Callable = torch.mean,
tbptt_pad_token: int = 0,

@@ -190,8 +190,8 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
value: value name
prog_bar: if True logs to the progress bar
logger: if True logs to the logger
- on_step: if True logs the output of validation_step or test_step
- on_epoch: if True, logs the output of the training loop aggregated
+ on_step: if True logs at this step. None auto-logs for training_step but not validation/test_step
+ on_epoch: if True logs epoch accumulated metrics. None auto-logs for val/test step but not training_step
reduce_fx: torch.mean by default
tbptt_reduce_fx: function to reduce on truncated back prop
tbptt_pad_token: token to use for padding

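With the Union[None, bool] defaults, the same self.log call works from training and evaluation hooks and the step/epoch behaviour is inferred from whichever hook is running. A minimal, illustrative usage sketch (the model, metric names and Trainer arguments below are my own, not part of this commit):

import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl

class LitRegressor(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(8, 1)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self(x), y)
        # on_step/on_epoch are left as None -> resolved to on_step=True, on_epoch=False here
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self(x), y)
        # in validation the None defaults resolve to on_step=False, on_epoch=True,
        # so values are accumulated and reduced (torch.mean by default) at epoch end
        self.log('val_loss', loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)

def _loader():
    x, y = torch.randn(64, 8), torch.randn(64, 1)
    return DataLoader(TensorDataset(x, y), batch_size=16)

trainer = pl.Trainer(max_epochs=1, limit_train_batches=2, limit_val_batches=2)
trainer.fit(LitRegressor(), _loader(), _loader())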
@@ -205,6 +205,10 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
if 'epoch_end' in self._current_fx_name and on_step:
    on_step = False

+ # set the default depending on the fx_name
+ on_step = self.__auto_choose_log_on_step(on_step)
+ on_epoch = self.__auto_choose_log_on_epoch(on_epoch)
+
self._results.log(
    name,
    value,

@@ -221,6 +225,30 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
    sync_dist_group
)

+ def __auto_choose_log_on_step(self, on_step):
+     if on_step is None:
+         if self._current_fx_name in {'training_step', 'training_step_end'}:
+             on_step = True
+         elif self._current_fx_name in {'evaluation_step', 'evaluation_step_end',
+                                        'evaluation_epoch_end', 'training_epoch_end'}:
+             on_step = False
+         else:
+             on_step = False
+
+     return on_step
+
+ def __auto_choose_log_on_epoch(self, on_epoch):
+     if on_epoch is None:
+         if self._current_fx_name in {'training_step', 'training_step_end'}:
+             on_epoch = False
+         elif self._current_fx_name in {'evaluation_step', 'evaluation_step_end',
+                                        'evaluation_epoch_end', 'training_epoch_end'}:
+             on_epoch = True
+         else:
+             on_epoch = True
+
+     return on_epoch
+
def forward(self, *args, **kwargs):
    r"""
    Same as :meth:`torch.nn.Module.forward()`, however in Lightning you want this to define

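For reference, the two private helpers amount to a small decision table: on_step defaults to True only inside training_step / training_step_end, on_epoch defaults to True everywhere else, and explicitly passed flags are left untouched. A standalone restatement with a few sanity checks (function and variable names here are mine, not from the commit):

def resolve_log_defaults(fx_name, on_step=None, on_epoch=None):
    """Mimics LightningModule.__auto_choose_log_on_step/_on_epoch from this commit."""
    train_steps = {'training_step', 'training_step_end'}
    if on_step is None:
        on_step = fx_name in train_steps          # log per step only while training
    if on_epoch is None:
        on_epoch = fx_name not in train_steps     # aggregate per epoch everywhere else
    return on_step, on_epoch

assert resolve_log_defaults('training_step') == (True, False)
assert resolve_log_defaults('evaluation_step') == (False, True)
assert resolve_log_defaults('training_epoch_end') == (False, True)
assert resolve_log_defaults('validation_step', on_step=True) == (True, True)  # explicit flags win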
@@ -352,9 +352,14 @@ class Result(Dict):
@classmethod
def reduce_on_epoch_end(cls, outputs):
    # get the batch sizes for all outputs
-   batch_sizes = torch.stack([x.get_batch_sizes() for x in outputs]).view(-1)
+   batch_sizes = []
+   meta = {}
+   for x in outputs:
+       batch_sizes.append(x.get_batch_sizes())
+       meta.update(x['meta'])
+
+   batch_sizes = torch.stack(batch_sizes).view(-1)

-   meta = outputs[0]['meta']
    result = cls()
    result = recursive_gather(outputs, result)
    recursive_stack(result)

@@ -371,6 +376,8 @@ class Result(Dict):
reduced_val = fx(result[k])

result[k] = reduced_val
else:
del result[k]

result['meta'] = meta
return result

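Merging meta across every output (instead of copying outputs[0]['meta']) matters when different steps log different keys, for example a metric that is only logged conditionally. A tiny illustration of the difference, using plain dicts and hypothetical keys:

step_metas = [
    {'val_loss': {'on_epoch': True}},                                   # step 0 logged only val_loss
    {'val_loss': {'on_epoch': True}, 'val_acc': {'on_epoch': True}},    # step 1 also logged val_acc
]

old_meta = step_metas[0]                 # previous behaviour: first step only -> 'val_acc' is lost
new_meta = {}
for m in step_metas:                     # new behaviour: union of everything that was logged
    new_meta.update(m)

assert 'val_acc' not in old_meta and 'val_acc' in new_meta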
@@ -871,7 +878,7 @@ class EvalResult(Result):


def weighted_mean(result, weights):
-   weights = weights.to(result.device)
+   weights = weights.to(result.device)[:result.size(0)]
    numerator = torch.dot(result.float(), weights.transpose(-1, 0).float())
    result = numerator / weights.sum().float()
    return result

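The weighted_mean change guards against a weights tensor that is longer than result, which can happen when batch sizes are tracked for more steps than actually contributed values for a given key; slicing keeps the dot product well formed. A standalone sketch with illustrative numbers:

import torch

def weighted_mean(result, weights):
    # mirrors the updated helper: align weights to the number of reduced values
    weights = weights.to(result.device)[:result.size(0)]
    numerator = torch.dot(result.float(), weights.transpose(-1, 0).float())
    return numerator / weights.sum().float()

values = torch.tensor([0.5, 0.7, 1.0])          # per-step means for one key
batch_sizes = torch.tensor([32, 32, 8, 16])     # one extra tracked batch size
print(weighted_mean(values, batch_sizes))       # tensor(0.6444), weighted by the first three sizes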
@@ -20,6 +20,7 @@ from pytorch_lightning.core.step_result import EvalResult, Result
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pprint import pprint
from typing import Iterable
+ from copy import deepcopy


class LoggerConnector:

@@ -29,6 +30,7 @@ class LoggerConnector:
self.callback_metrics = {}
self.logged_metrics = {}
self.progress_bar_metrics = {}
+ self.eval_loop_results = []

def on_trainer_init(self, logger, log_save_interval, row_log_interval):
    # logging

@@ -100,11 +102,69 @@ class LoggerConnector:
self.trainer.dev_debugger.track_pbar_metrics_history(metrics)

def on_evaluation_epoch_end(self, eval_results, using_eval_result, test_mode):
    # TODO: merge both functions?
-   self._log_on_evaluation_epoch_end_metrics(eval_results, using_eval_result)
-   return self.__log_evaluation_epoch_metrics_2(eval_results, test_mode)
+   self._track_callback_metrics(eval_results, using_eval_result)
+   self._log_on_evaluation_epoch_end_metrics()

- def _log_on_evaluation_epoch_end_metrics(self, eval_results, using_eval_result):
    # TODO: deprecate parts of this for 1.0 (when removing results)
    self.__process_eval_epoch_end_results_and_log_legacy(eval_results, test_mode)

    # get the final loop results
    eval_loop_results = self._get_evaluate_epoch_results(test_mode)
    return eval_loop_results

def _get_evaluate_epoch_results(self, test_mode):
    # log results of test
    if test_mode and self.trainer.is_global_zero and self.trainer.verbose_test:
        print('-' * 80)
        for result_idx, results in enumerate(self.eval_loop_results):
            print(f'DATALOADER:{result_idx} TEST RESULTS')
            pprint(results)
            print('-' * 80)

    results = self.eval_loop_results

    # clear mem
    self.eval_loop_results = []
    return results

def _log_on_evaluation_epoch_end_metrics(self):
    step_metrics = self.trainer.evaluation_loop.step_metrics

    # clear mem
    self.trainer.evaluation_loop.step_metrics = []

    num_loaders = len(step_metrics)

    # process metrics per dataloader
    for dl_idx, dl_metrics in enumerate(step_metrics):
        if len(dl_metrics) == 0:
            continue

        reduced_epoch_metrics = dl_metrics[0].__class__.reduce_on_epoch_end(dl_metrics)
        # make the keys 'k/dl'
        reduced_epoch_metrics = self.__rename_keys_by_dataloader_idx(reduced_epoch_metrics, dl_idx, num_loaders)

        # track the metrics
        logger_metrics = reduced_epoch_metrics.get_epoch_log_metrics()
        pbar_metrics = reduced_epoch_metrics.get_epoch_pbar_metrics()
        self.logged_metrics.update(logger_metrics)
        self.progress_bar_metrics.update(pbar_metrics)

        # enable the metrics to be monitored
        self.callback_metrics.update(logger_metrics)
        self.callback_metrics.update(pbar_metrics)

        # track the final results for the dataloader
        self.eval_loop_results.append(deepcopy(self.callback_metrics))

def __rename_keys_by_dataloader_idx(self, metrics, dataloader_idx, num_loaders):
    if num_loaders == 1:
        return metrics

    result = {f'{k}/dataloader_idx_{dataloader_idx}': v for k, v in metrics.items()}
    return result

def _track_callback_metrics(self, eval_results, using_eval_result):
    if len(eval_results) > 0 and eval_results[0] is None:
        return

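With more than one val/test dataloader, the reduced epoch metrics are suffixed per loader so keys from different dataloaders do not collide. A quick standalone illustration of the renaming rule (metric names are hypothetical):

def rename_keys_by_dataloader_idx(metrics, dataloader_idx, num_loaders):
    # same renaming rule as the connector's private helper
    if num_loaders == 1:
        return metrics
    return {f'{k}/dataloader_idx_{dataloader_idx}': v for k, v in metrics.items()}

print(rename_keys_by_dataloader_idx({'val_acc': 0.9}, 0, 1))
# {'val_acc': 0.9}                  -> a single loader keeps plain keys
print(rename_keys_by_dataloader_idx({'val_acc': 0.9}, 1, 2))
# {'val_acc/dataloader_idx_1': 0.9} -> multiple loaders get a suffix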
@@ -141,11 +201,10 @@ class LoggerConnector:
flat['early_stop_on'] = flat['val_loss']
self.trainer.logger_connector.callback_metrics.update(flat)

- def __log_evaluation_epoch_metrics_2(self, eval_results, test_mode):
+ def __process_eval_epoch_end_results_and_log_legacy(self, eval_results, test_mode):
    if self.trainer.running_sanity_check:
        return

-   eval_loop_results = []
    if eval_results is not None and len(eval_results) > 0:

        # in eval, the user may return something at every validation step without final reduction

@@ -179,17 +238,7 @@ class LoggerConnector:
self.trainer.logger_connector.callback_metrics.update(prog_bar_metrics)

if len(dataloader_result_metrics) > 0:
-   eval_loop_results.append(dataloader_result_metrics)
-
- # log results of test
- if test_mode and self.trainer.is_global_zero and self.trainer.verbose_test:
-     print('-' * 80)
-     for result_idx, results in enumerate(eval_loop_results):
-         print(f'DATALOADER:{result_idx} TEST RESULTS')
-         pprint(results)
-         print('-' * 80)
-
- return eval_loop_results
+   self.eval_loop_results.append(dataloader_result_metrics)

def on_train_epoch_end(self, epoch_output):
    pass

@@ -23,6 +23,7 @@ class EvaluationLoop(object):
self.trainer = trainer
self.testing = False
self.outputs = []
+ self.step_metrics = []
self.predictions = None
self.max_batches = None

@@ -278,14 +279,21 @@ class EvaluationLoop(object):
else:
    self.trainer.call_hook('on_validation_epoch_end', *args, **kwargs)

- def log_evaluation_step_metrics(self, output, batch_idx):
+ def log_evaluation_step_metrics(self, batch, batch_idx):
    results = self.trainer.get_model()._results
    if len(results) == 1:
        return None

    results.track_batch_size(len(batch))
    self.__log_result_step_metrics(results, batch_idx)

    return results

# TODO: deprecate at 1.0
def log_evaluation_step_metrics_legacy(self, output, batch_idx):
    if self.trainer.running_sanity_check:
        return

    results = self.trainer.get_model()._results
    self.__log_result_step_metrics(results, batch_idx)

    # TODO: deprecate at 1.0
    if isinstance(output, EvalResult):
        self.__log_result_step_metrics(output, batch_idx)

@@ -413,6 +413,7 @@ class Trainer(
for dataloader_idx, dataloader in enumerate(dataloaders):
    # bookkeeping
    dl_outputs = []
+   dl_step_metrics = []
    dataloader = self.accelerator_backend.process_dataloader(dataloader)
    dl_max_batches = self.evaluation_loop.max_batches[dataloader_idx]

@@ -436,13 +437,22 @@ class Trainer(

# clean up
self.evaluation_loop.evaluation_batch_end_cleanup(output, batch_idx, dataloader_idx)
- self.evaluation_loop.log_evaluation_step_metrics(output, batch_idx)

# track epoch level metrics
# TODO: deprecate 1.0
+ self.evaluation_loop.log_evaluation_step_metrics_legacy(output, batch_idx)

# log step metrics
+ step_metrics = self.evaluation_loop.log_evaluation_step_metrics(batch, batch_idx)
+
+ if step_metrics is not None:
+     dl_step_metrics.append(step_metrics)

# track epoch level outputs
if output is not None:
    dl_outputs.append(output)

self.evaluation_loop.outputs.append(dl_outputs)
+ self.evaluation_loop.step_metrics.append(dl_step_metrics)

# lightning module method
eval_results = self.evaluation_loop.evaluation_epoch_end(num_dataloaders=len(dataloaders))

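The evaluation loop in the Trainer now keeps two parallel nested lists: outputs (whatever the step methods return, still used by the legacy epoch-end path) and step_metrics (the Result objects produced via self.log), each holding one list per dataloader with one entry per batch. A rough sketch of that bookkeeping shape, with stand-in values:

# shape of EvaluationLoop.step_metrics after two val dataloaders of 3 and 2 batches:
# step_metrics[dataloader_idx][batch_idx] -> the Result logged in that validation_step
step_metrics = [
    ['result_dl0_b0', 'result_dl0_b1', 'result_dl0_b2'],   # dataloader 0
    ['result_dl1_b0', 'result_dl1_b1'],                     # dataloader 1
]

# the logger connector later reduces each inner list to one epoch-level dict per dataloader
per_loader_epoch_metrics = [f'reduced({len(dl)} batches)' for dl in step_metrics]
print(per_loader_epoch_metrics)  # ['reduced(3 batches)', 'reduced(2 batches)']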
@@ -17,14 +17,14 @@ def test__validation_step__log(tmpdir):
def training_step(self, batch, batch_idx):
    acc = self.step(batch, batch_idx)
    acc = acc + batch_idx
-   self.log('train_step_acc', acc, on_step=True, on_epoch=True)
+   self.log('a', acc, on_step=True, on_epoch=True)
    self.training_step_called = True
    return acc

def validation_step(self, batch, batch_idx):
    acc = self.step(batch, batch_idx)
    acc = acc + batch_idx
-   self.log('val_step_acc', acc, on_step=True, on_epoch=True)
+   self.log('b', acc, on_step=True, on_epoch=True)
    self.training_step_called = True

def backward(self, trainer, loss, optimizer, optimizer_idx):

@@ -46,18 +46,103 @@ def test__validation_step__log(tmpdir):

    # make sure all the metrics are available for callbacks
    expected_logged_metrics = {
        'a',
        'step_a',
        'epoch_a',
        'b',
        'step_b/epoch_0',
        'step_b/epoch_1',
        'b/epoch_0',
        'b/epoch_1',
        'epoch_b',
        'epoch',
-       'train_step_acc', 'step_train_step_acc', 'epoch_train_step_acc',
-       'val_step_acc/epoch_0', 'val_step_acc/epoch_1',
-       'step_val_step_acc/epoch_0', 'step_val_step_acc/epoch_1',
    }
    logged_metrics = set(trainer.logged_metrics.keys())
    assert expected_logged_metrics == logged_metrics

    # we don't want to enable val metrics during steps because it is not something that users should do
-   expected_cb_metrics = [
-       'train_step_acc', 'step_train_step_acc', 'epoch_train_step_acc',
-   ]
-   expected_cb_metrics = set(expected_cb_metrics)
    # on purpose DO NOT allow step_b... it's silly to monitor val step metrics
+   expected_cb_metrics = {'a', 'b', 'epoch_a', 'epoch_b', 'step_a'}
    callback_metrics = set(trainer.callback_metrics.keys())
    assert expected_cb_metrics == callback_metrics


def test__validation_step__epoch_end__log(tmpdir):
    """
    Tests that validation_step can log
    """
    os.environ['PL_DEV_DEBUG'] = '1'

    class TestModel(DeterministicModel):
        def training_step(self, batch, batch_idx):
            acc = self.step(batch, batch_idx)
            acc = acc + batch_idx
            self.log('a', acc)
            self.log('b', acc, on_step=True, on_epoch=True)
            self.training_step_called = True
            return acc

        def validation_step(self, batch, batch_idx):
            acc = self.step(batch, batch_idx)
            acc = acc + batch_idx
            self.log('c', acc)
            self.log('d', acc, on_step=True, on_epoch=True)
            self.validation_step_called = True

        def validation_epoch_end(self, outputs):
            self.log('e', torch.tensor(2, device=self.device), on_step=True, on_epoch=True)
            self.validation_epoch_end_called = True

        def backward(self, trainer, loss, optimizer, optimizer_idx):
            loss.backward()

    model = TestModel()
    model.validation_step_end = None

    trainer = Trainer(
        default_root_dir=tmpdir,
        limit_train_batches=2,
        limit_val_batches=2,
        max_epochs=2,
        row_log_interval=1,
        weights_summary=None,
    )
    trainer.fit(model)

    # make sure all the metrics are available for callbacks
    expected_logged_metrics = {
        'a',
        'b',
        'step_b',
        'epoch_b',
        'c',
        'd',
        'd/epoch_0',
        'd/epoch_1',
        'step_d/epoch_0',
        'step_d/epoch_1',
        'epoch_d',
        'e',
        'epoch_e',
        'epoch',
    }

    logged_metrics = set(trainer.logged_metrics.keys())
    assert expected_logged_metrics == logged_metrics

    # we don't want to enable val metrics during steps because it is not something that users should do
    expected_cb_metrics = {
        'a',
        'b',
        'step_b',
        'epoch_b',
        'c',
        'd',
        'epoch_d',
        'e',
        'epoch_e',
        'debug_epoch',
    }

    callback_metrics = set(trainer.callback_metrics.keys())
    assert expected_cb_metrics == callback_metrics
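The expected key sets above encode the naming convention being tested: a metric logged with on_step=True and on_epoch=True appears under its plain name plus step_/epoch_ variants, and per-step validation values additionally get an /epoch_N suffix since validation steps repeat every epoch. A compact restatement of that convention (the helper name and signature are mine, not part of the test suite):

def expected_keys(name, on_step, on_epoch, is_val, num_epochs=2):
    """Rough restatement of the key naming asserted by the tests above."""
    keys = {name}
    if on_step and on_epoch:
        if is_val:
            # validation step values are re-logged on every epoch pass
            keys |= {f'step_{name}/epoch_{e}' for e in range(num_epochs)}
            keys |= {f'{name}/epoch_{e}' for e in range(num_epochs)}
        else:
            keys.add(f'step_{name}')
        keys.add(f'epoch_{name}')
    return keys

print(sorted(expected_keys('b', True, True, is_val=False)))  # ['b', 'epoch_b', 'step_b']
print(sorted(expected_keys('d', True, True, is_val=True)))
# ['d', 'd/epoch_0', 'd/epoch_1', 'epoch_d', 'step_d/epoch_0', 'step_d/epoch_1']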
|
||||
|
|
Loading…
Reference in New Issue