ref: enable self.log from val step (#3701)

* .log in eval

* ref

* ref: enable self.log in val step
William Falcon 2020-09-28 10:49:07 -04:00 committed by GitHub
parent 2ecaa2a8be
commit cdd7266cd8
5 changed files with 91 additions and 20 deletions
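
For orientation before the per-file diffs: what this commit enables, from the user's side, is calling self.log from validation_step just as from training_step. A minimal sketch of the end-user API, assuming the 0.9.x-era Lightning interface (the module internals and metric names here are illustrative, not taken from the diff):

    import torch
    import pytorch_lightning as pl

    class LitModel(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)

        def training_step(self, batch, batch_idx):
            loss = self.layer(batch).sum()
            # logging from the training step already worked before this commit
            self.log('train_loss', loss, on_step=True, on_epoch=True)
            return loss

        def validation_step(self, batch, batch_idx):
            loss = self.layer(batch).sum()
            # new with this commit: self.log is also honored here
            self.log('val_loss', loss, on_step=True, on_epoch=True)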

View File

@@ -87,7 +87,7 @@ class LoggerConnector:
             self.trainer.logger.save()
 
         # track the logged metrics
-        self.logged_metrics = scalar_metrics
+        self.logged_metrics.update(scalar_metrics)
         self.trainer.dev_debugger.track_logged_metrics_history(scalar_metrics)
 
     def add_progress_bar_metrics(self, metrics):
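
The one-line change above is load-bearing: with plain assignment, each logging event wiped whatever the previous phase had logged, so enabling val-step logging would have clobbered the train-step metrics. A stand-alone sketch of the difference (dict contents are illustrative):

    # old behavior: assignment replaces the tracked metrics wholesale
    logged_metrics = {'train_step_acc': 0.8}
    logged_metrics = {'val_step_acc': 0.7}
    assert 'train_step_acc' not in logged_metrics

    # new behavior: update() merges, so train and val metrics coexist
    logged_metrics = {'train_step_acc': 0.8}
    logged_metrics.update({'val_step_acc': 0.7})
    assert logged_metrics == {'train_step_acc': 0.8, 'val_step_acc': 0.7}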
@@ -191,9 +191,8 @@ class LoggerConnector:
         return eval_loop_results
 
-    def on_train_epoch_end(self, epoch_output, checkpoint_accumulator, early_stopping_accumulator, num_optimizers):
-        self.log_train_epoch_end_metrics(epoch_output, checkpoint_accumulator,
-                                         early_stopping_accumulator, num_optimizers)
+    def on_train_epoch_end(self, epoch_output):
+        pass
 
     def log_train_epoch_end_metrics(self,
                                     epoch_output,
@@ -413,7 +412,7 @@ class LoggerConnector:
         return gathered_epoch_outputs
 
-    def save_train_loop_metrics_to_loggers(self, batch_idx, batch_output):
+    def log_train_step_metrics(self, batch_idx, batch_output):
         # when metrics should be logged
         should_log_metrics = (batch_idx + 1) % self.trainer.row_log_interval == 0 or self.trainer.should_stop
 
        if should_log_metrics or self.trainer.fast_dev_run:
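
The rename does not change the gating logic: train-step metrics go to loggers every row_log_interval batches, when training is about to stop, or unconditionally under fast_dev_run. Extracted as a pure function for illustration (the function signature is ours, not the library's):

    def should_log_metrics(batch_idx, row_log_interval, should_stop, fast_dev_run=False):
        # batch_idx is 0-based, hence the +1 for "every N-th batch"
        return (batch_idx + 1) % row_log_interval == 0 or should_stop or fast_dev_run

    assert should_log_metrics(0, row_log_interval=1, should_stop=False)
    assert not should_log_metrics(0, row_log_interval=2, should_stop=False)
    assert should_log_metrics(1, row_log_interval=2, should_stop=False)
    assert should_log_metrics(0, row_log_interval=2, should_stop=True)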

View File

@@ -245,6 +245,11 @@ class EvaluationLoop(object):
         return eval_results
 
     def on_evaluation_batch_start(self, *args, **kwargs):
+        # reset the result of the PL module
+        model = self.trainer.get_model()
+        model._results = Result()
+        model._current_fx_name = 'evaluation_step'
+
         if self.testing:
             self.trainer.call_hook('on_test_batch_start', *args, **kwargs)
         else:
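
This reset is the core of the feature: a fresh Result object is attached to the module before each eval batch, self.log writes into it during the step, and the loop reads it back afterwards. A much-simplified sketch of that round trip (both classes are stand-ins, not the real Result or LightningModule):

    class FakeResult(dict):
        # stand-in for pytorch_lightning's Result accumulator
        def log(self, name, value):
            self[name] = value

    class FakeModule:
        _results = None

        def log(self, name, value):
            # self.log forwards to the Result of the step in flight
            self._results.log(name, value)

    model = FakeModule()
    model._results = FakeResult()      # the loop resets this before the batch
    model.log('val_step_acc', 0.9)     # user code inside validation_step
    assert model._results['val_step_acc'] == 0.9  # the loop reads it back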
@@ -273,11 +278,18 @@ class EvaluationLoop(object):
         else:
             self.trainer.call_hook('on_validation_epoch_end', *args, **kwargs)
 
-    def log_step_metrics(self, output, batch_idx):
+    def log_evaluation_step_metrics(self, output, batch_idx):
         if self.trainer.running_sanity_check:
             return
 
+        results = self.trainer.get_model()._results
+        self.__log_result_step_metrics(results, batch_idx)
+
+        # TODO: deprecate at 1.0
         if isinstance(output, EvalResult):
             self.__log_result_step_metrics(output, batch_idx)
 
     def __log_result_step_metrics(self, output, batch_idx):
         step_log_metrics = output.batch_log_metrics
         step_pbar_metrics = output.batch_pbar_metrics
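
After this change the eval loop therefore logs from two sources: metrics collected through self.log (always), and metrics carried in a returned EvalResult (kept only for backward compatibility, per the new TODO). A rough sketch of the dual path, with stand-in types:

    class FakeEvalResult(dict):
        # stand-in for the legacy EvalResult return type
        pass

    def log_step(results, output, batch_idx, sink):
        sink.append((results, batch_idx))          # new self.log path
        if isinstance(output, FakeEvalResult):     # legacy path, to be deprecated
            sink.append((output, batch_idx))

    sink = []
    log_step({'val_acc': 0.9}, FakeEvalResult(val_loss=0.1), batch_idx=0, sink=sink)
    assert len(sink) == 2  # both paths fire for a legacy-style step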

View File

@@ -436,7 +436,7 @@ class Trainer(
             # clean up
             self.evaluation_loop.evaluation_batch_end_cleanup(output, batch_idx, dataloader_idx)
 
-            self.evaluation_loop.log_step_metrics(output, batch_idx)
+            self.evaluation_loop.log_evaluation_step_metrics(output, batch_idx)
 
             # track epoch level metrics
             if output is not None:

View File

@@ -542,7 +542,7 @@ class TrainLoop:
                 # -----------------------------------------
                 # SAVE METRICS TO LOGGERS
                 # -----------------------------------------
-                self.trainer.logger_connector.save_train_loop_metrics_to_loggers(batch_idx, batch_output)
+                self.trainer.logger_connector.log_train_step_metrics(batch_idx, batch_output)
 
                 # -----------------------------------------
                 # SAVE LOGGERS (ie: Tensorboard, etc...)
@@ -573,14 +573,17 @@ class TrainLoop:
             # progress global step according to grads progress
             self.increment_accumulated_grad_global_step()
 
-        # process epoch outputs
-        self.trainer.logger_connector.on_train_epoch_end(
+        # log epoch metrics
+        self.trainer.logger_connector.log_train_epoch_end_metrics(
             epoch_output,
             self.checkpoint_accumulator,
             self.early_stopping_accumulator,
             self.num_optimizers
         )
 
+        # hook
+        self.trainer.logger_connector.on_train_epoch_end(epoch_output)
+
         # when no val loop is present or fast-dev-run still need to call checkpoints
         self.check_checkpoint_callback(not (should_check_val or is_overridden('validation_step', model)))
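
The reshuffle above splits one call into two: an explicit logging call that still takes the accumulators, followed by the hook, which now receives only the epoch output (and is a no-op in the connector, per the LoggerConnector diff). The resulting order, as a runnable sketch with a dummy connector:

    class FakeLoggerConnector:
        def __init__(self):
            self.calls = []

        def log_train_epoch_end_metrics(self, epoch_output, ckpt_acc, es_acc, num_optimizers):
            self.calls.append('log')   # explicit logging, with accumulators

        def on_train_epoch_end(self, epoch_output):
            self.calls.append('hook')  # pure hook, mirrors the new no-op

    connector = FakeLoggerConnector()
    connector.log_train_epoch_end_metrics([], None, None, 1)
    connector.on_train_epoch_end([])
    assert connector.calls == ['log', 'hook']  # log first, then fire the hook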
@@ -704,6 +707,7 @@ class TrainLoop:
         batch_log_metrics = {k: v for d in batch_log_metrics for k, v in d.items()}
 
         # track all metrics for callbacks
+        # TODO: is this needed?
         self.trainer.logger_connector.callback_metrics.update(
             {k: v for d in batch_callback_metrics for k, v in d.items() if v is not None}
         )
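
For reference, the comprehension under the new TODO flattens a list of per-output metric dicts into one dict while dropping None values; in isolation:

    batch_callback_metrics = [{'acc': 0.9, 'f1': None}, {'loss': 0.1}]
    merged = {k: v for d in batch_callback_metrics for k, v in d.items() if v is not None}
    assert merged == {'acc': 0.9, 'loss': 0.1}  # None entries are filtered out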

View File

@@ -215,3 +215,59 @@ def test_training_step_dict(tmpdir):
     # epoch 1
     assert trainer.dev_debugger.logged_metrics[3]['global_step'] == 2
     assert trainer.dev_debugger.logged_metrics[4]['global_step'] == 3
+
+
+def test_validation_step_logging(tmpdir):
+    """
+    Tests that validation_step can use self.log
+    """
+    os.environ['PL_DEV_DEBUG'] = '1'
+
+    class TestModel(DeterministicModel):
+        def training_step(self, batch, batch_idx):
+            acc = self.step(batch, batch_idx)
+            acc = acc + batch_idx
+            self.log('train_step_acc', acc, on_step=True, on_epoch=True)
+            self.training_step_called = True
+            return acc
+
+        def validation_step(self, batch, batch_idx):
+            acc = self.step(batch, batch_idx)
+            acc = acc + batch_idx
+            self.log('val_step_acc', acc, on_step=True, on_epoch=True)
+            self.validation_step_called = True
+
+        def backward(self, trainer, loss, optimizer, optimizer_idx):
+            loss.backward()
+
+    model = TestModel()
+    model.validation_step_end = None
+    model.validation_epoch_end = None
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        limit_train_batches=2,
+        limit_val_batches=2,
+        max_epochs=2,
+        row_log_interval=1,
+        weights_summary=None,
+    )
+    trainer.fit(model)
+
+    # make sure all of the metrics reached the loggers
+    expected_logged_metrics = {
+        'epoch',
+        'train_step_acc', 'step_train_step_acc', 'epoch_train_step_acc',
+        'val_step_acc/epoch_0', 'val_step_acc/epoch_1',
+        'step_val_step_acc/epoch_0', 'step_val_step_acc/epoch_1',
+    }
+    logged_metrics = set(trainer.logged_metrics.keys())
+    assert expected_logged_metrics == logged_metrics
+
+    # per-step val metrics are deliberately kept out of callback_metrics;
+    # reacting to them from callbacks is not something users should do
+    expected_cb_metrics = {'train_step_acc', 'step_train_step_acc', 'epoch_train_step_acc'}
+    callback_metrics = set(trainer.callback_metrics.keys())
+    assert expected_cb_metrics == callback_metrics
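
Note the key naming the test asserts on: logging with on_step=True and on_epoch=True produces a raw, a step_-prefixed, and an epoch_-prefixed variant of each train metric, while validation metrics are additionally namespaced per epoch. The suffixing scheme, reconstructed from the expected set above (an illustrative helper, not library code):

    def epoch_suffixed(name, epoch):
        # e.g. 'val_step_acc' logged in epoch 0 -> 'val_step_acc/epoch_0'
        return f'{name}/epoch_{epoch}'

    assert epoch_suffixed('val_step_acc', 0) == 'val_step_acc/epoch_0'
    assert epoch_suffixed('step_val_step_acc', 1) == 'step_val_step_acc/epoch_1'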