ref: enable self.log from val step (#3701)
* .log in eval * ref * ref: enable self.log in val step
This commit is contained in:
parent
2ecaa2a8be
commit
cdd7266cd8
|
@ -87,7 +87,7 @@ class LoggerConnector:
|
|||
self.trainer.logger.save()
|
||||
|
||||
# track the logged metrics
|
||||
self.logged_metrics = scalar_metrics
|
||||
self.logged_metrics.update(scalar_metrics)
|
||||
self.trainer.dev_debugger.track_logged_metrics_history(scalar_metrics)
|
||||
|
||||
def add_progress_bar_metrics(self, metrics):
|
||||
|
@ -191,9 +191,8 @@ class LoggerConnector:
|
|||
|
||||
return eval_loop_results
|
||||
|
||||
def on_train_epoch_end(self, epoch_output, checkpoint_accumulator, early_stopping_accumulator, num_optimizers):
|
||||
self.log_train_epoch_end_metrics(epoch_output, checkpoint_accumulator,
|
||||
early_stopping_accumulator, num_optimizers)
|
||||
def on_train_epoch_end(self, epoch_output):
|
||||
pass
|
||||
|
||||
def log_train_epoch_end_metrics(self,
|
||||
epoch_output,
|
||||
|
@ -413,7 +412,7 @@ class LoggerConnector:
|
|||
|
||||
return gathered_epoch_outputs
|
||||
|
||||
def save_train_loop_metrics_to_loggers(self, batch_idx, batch_output):
|
||||
def log_train_step_metrics(self, batch_idx, batch_output):
|
||||
# when metrics should be logged
|
||||
should_log_metrics = (batch_idx + 1) % self.trainer.row_log_interval == 0 or self.trainer.should_stop
|
||||
if should_log_metrics or self.trainer.fast_dev_run:
|
||||
|
|
|
@ -245,6 +245,11 @@ class EvaluationLoop(object):
|
|||
return eval_results
|
||||
|
||||
def on_evaluation_batch_start(self, *args, **kwargs):
|
||||
# reset the result of the PL module
|
||||
model = self.trainer.get_model()
|
||||
model._results = Result()
|
||||
model._current_fx_name = 'evaluation_step'
|
||||
|
||||
if self.testing:
|
||||
self.trainer.call_hook('on_test_batch_start', *args, **kwargs)
|
||||
else:
|
||||
|
@ -273,11 +278,18 @@ class EvaluationLoop(object):
|
|||
else:
|
||||
self.trainer.call_hook('on_validation_epoch_end', *args, **kwargs)
|
||||
|
||||
def log_step_metrics(self, output, batch_idx):
|
||||
def log_evaluation_step_metrics(self, output, batch_idx):
|
||||
if self.trainer.running_sanity_check:
|
||||
return
|
||||
|
||||
results = self.trainer.get_model()._results
|
||||
self.__log_result_step_metrics(results, batch_idx)
|
||||
|
||||
# TODO: deprecate at 1.0
|
||||
if isinstance(output, EvalResult):
|
||||
self.__log_result_step_metrics(output, batch_idx)
|
||||
|
||||
def __log_result_step_metrics(self, output, batch_idx):
|
||||
step_log_metrics = output.batch_log_metrics
|
||||
step_pbar_metrics = output.batch_pbar_metrics
|
||||
|
||||
|
|
|
@ -436,7 +436,7 @@ class Trainer(
|
|||
|
||||
# clean up
|
||||
self.evaluation_loop.evaluation_batch_end_cleanup(output, batch_idx, dataloader_idx)
|
||||
self.evaluation_loop.log_step_metrics(output, batch_idx)
|
||||
self.evaluation_loop.log_evaluation_step_metrics(output, batch_idx)
|
||||
|
||||
# track epoch level metrics
|
||||
if output is not None:
|
||||
|
|
|
@ -542,7 +542,7 @@ class TrainLoop:
|
|||
# -----------------------------------------
|
||||
# SAVE METRICS TO LOGGERS
|
||||
# -----------------------------------------
|
||||
self.trainer.logger_connector.save_train_loop_metrics_to_loggers(batch_idx, batch_output)
|
||||
self.trainer.logger_connector.log_train_step_metrics(batch_idx, batch_output)
|
||||
|
||||
# -----------------------------------------
|
||||
# SAVE LOGGERS (ie: Tensorboard, etc...)
|
||||
|
@ -573,14 +573,17 @@ class TrainLoop:
|
|||
# progress global step according to grads progress
|
||||
self.increment_accumulated_grad_global_step()
|
||||
|
||||
# process epoch outputs
|
||||
self.trainer.logger_connector.on_train_epoch_end(
|
||||
# log epoch metrics
|
||||
self.trainer.logger_connector.log_train_epoch_end_metrics(
|
||||
epoch_output,
|
||||
self.checkpoint_accumulator,
|
||||
self.early_stopping_accumulator,
|
||||
self.num_optimizers
|
||||
)
|
||||
|
||||
# hook
|
||||
self.trainer.logger_connector.on_train_epoch_end(epoch_output)
|
||||
|
||||
# when no val loop is present or fast-dev-run still need to call checkpoints
|
||||
self.check_checkpoint_callback(not (should_check_val or is_overridden('validation_step', model)))
|
||||
|
||||
|
@ -704,6 +707,7 @@ class TrainLoop:
|
|||
batch_log_metrics = {k: v for d in batch_log_metrics for k, v in d.items()}
|
||||
|
||||
# track all metrics for callbacks
|
||||
# TODO: is this needed?
|
||||
self.trainer.logger_connector.callback_metrics.update(
|
||||
{k: v for d in batch_callback_metrics for k, v in d.items() if v is not None}
|
||||
)
|
||||
|
|
|
@ -215,3 +215,59 @@ def test_training_step_dict(tmpdir):
|
|||
# epoch 1
|
||||
assert trainer.dev_debugger.logged_metrics[3]['global_step'] == 2
|
||||
assert trainer.dev_debugger.logged_metrics[4]['global_step'] == 3
|
||||
|
||||
|
||||
def test_validation_step_logging(tmpdir):
|
||||
"""
|
||||
Tests that only training_step can be used
|
||||
"""
|
||||
os.environ['PL_DEV_DEBUG'] = '1'
|
||||
|
||||
class TestModel(DeterministicModel):
|
||||
def training_step(self, batch, batch_idx):
|
||||
acc = self.step(batch, batch_idx)
|
||||
acc = acc + batch_idx
|
||||
self.log('train_step_acc', acc, on_step=True, on_epoch=True)
|
||||
self.training_step_called = True
|
||||
return acc
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
acc = self.step(batch, batch_idx)
|
||||
acc = acc + batch_idx
|
||||
self.log('val_step_acc', acc, on_step=True, on_epoch=True)
|
||||
self.training_step_called = True
|
||||
|
||||
def backward(self, trainer, loss, optimizer, optimizer_idx):
|
||||
loss.backward()
|
||||
|
||||
model = TestModel()
|
||||
model.validation_step_end = None
|
||||
model.validation_epoch_end = None
|
||||
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
limit_train_batches=2,
|
||||
limit_val_batches=2,
|
||||
max_epochs=2,
|
||||
row_log_interval=1,
|
||||
weights_summary=None,
|
||||
)
|
||||
trainer.fit(model)
|
||||
|
||||
# make sure all the metrics are available for callbacks
|
||||
expected_logged_metrics = {
|
||||
'epoch',
|
||||
'train_step_acc', 'step_train_step_acc', 'epoch_train_step_acc',
|
||||
'val_step_acc/epoch_0', 'val_step_acc/epoch_1',
|
||||
'step_val_step_acc/epoch_0', 'step_val_step_acc/epoch_1',
|
||||
}
|
||||
logged_metrics = set(trainer.logged_metrics.keys())
|
||||
assert expected_logged_metrics == logged_metrics
|
||||
|
||||
# we don't want to enable val metrics during steps because it is not something that users should do
|
||||
expected_cb_metrics = [
|
||||
'train_step_acc', 'step_train_step_acc', 'epoch_train_step_acc',
|
||||
]
|
||||
expected_cb_metrics = set(expected_cb_metrics)
|
||||
callback_metrics = set(trainer.callback_metrics.keys())
|
||||
assert expected_cb_metrics == callback_metrics
|
||||
|
|
Loading…
Reference in New Issue