re-enabled metric names in the checkpoint filename (#3060)

William Falcon 2020-08-19 20:34:09 -04:00 committed by GitHub
parent cefc7f7c32
commit 3453bba898
6 changed files with 69 additions and 8 deletions
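
In short: the metrics most recently logged by the trainer are now kept on trainer.logged_metrics and routed into ModelCheckpoint's filename formatting. A minimal usage sketch, assuming the 0.9-era API used in the diffs below (the path, the metric value, and 'model' are illustrative, not part of this change):

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# '{val_loss:.2f}' is filled from the metrics the trainer has logged
ckpt = ModelCheckpoint(filepath='checkpoints/{epoch}-{val_loss:.2f}')
trainer = Trainer(max_epochs=1, checkpoint_callback=ckpt)
trainer.fit(model)  # given some LightningModule 'model'; a logged val_loss of 0.25 would yield e.g. 'epoch=0-val_loss=0.25.ckpt'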

View File

@@ -339,10 +339,11 @@ class ModelCheckpoint(Callback):
         self.epoch_last_check = epoch

-        filepath = self.format_checkpoint_name(epoch, metrics)
+        ckpt_name_metrics = trainer.logged_metrics
+        filepath = self.format_checkpoint_name(epoch, ckpt_name_metrics)
         version_cnt = 0
         while gfile.exists(filepath):
-            filepath = self.format_checkpoint_name(epoch, metrics, ver=version_cnt)
+            filepath = self.format_checkpoint_name(epoch, ckpt_name_metrics, ver=version_cnt)
             # this epoch called before
             version_cnt += 1
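
Read on its own, the new lookup amounts to the following sketch (format_checkpoint_name filling '{metric}' placeholders from the dict it receives is the pre-existing behaviour; the values are made up):

# illustrative only
ckpt_name_metrics = trainer.logged_metrics        # e.g. {'val_loss': 0.25, 'epoch': 0}
filepath = checkpoint_callback.format_checkpoint_name(0, ckpt_name_metrics)
# with filepath='.../{val_loss:.2f}' this formats to something like '.../val_loss=0.25.ckpt'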

View File

@@ -24,6 +24,7 @@ class TrainerLoggingMixin(ABC):
     default_root_dir: str
     slurm_job_id: int
     num_gpus: int
+    logged_metrics: ...

     def configure_logger(self, logger):
         if logger is True:
@@ -75,6 +76,8 @@ class TrainerLoggingMixin(ABC):
             self.logger.agg_and_log_metrics(scalar_metrics, step=step)
             self.logger.save()
+            # track the logged metrics
+            self.logged_metrics = scalar_metrics

         self.dev_debugger.track_logged_metrics_history(scalar_metrics)

     def add_progress_bar_metrics(self, metrics):
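
Seen from outside the mixin, the last batch of scalar metrics sent to the logger is now also kept on the trainer. A hedged sketch of how that would be observed (values made up):

# illustrative only
trainer.fit(model)
print(trainer.logged_metrics)   # e.g. {'val_loss': 0.25, 'train_loss': 0.31, 'epoch': 0}
# ModelCheckpoint reads this dict (see the callback diff above) to fill its filename template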

View File

@@ -374,6 +374,7 @@ class Trainer(
         self.batch_idx = 0
         self.progress_bar_metrics = {}
         self.callback_metrics = {}
+        self.logged_metrics = {}
         self.num_training_batches = 0
         self.num_val_batches = []
         self.num_test_batches = []
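
The empty-dict default presumably covers the case where a checkpoint name is formatted before anything has been logged; a sketch of the assumed fallback (the fill-missing-metrics-with-0 behaviour is my reading of the existing format_checkpoint_name, not part of this diff):

# illustrative only
ckpt_name_metrics = trainer.logged_metrics   # {} right after Trainer() is constructed
filepath = checkpoint_callback.format_checkpoint_name(0, ckpt_name_metrics)
# an unresolved '{val_loss:.2f}' placeholder would then be formatted from a default value of 0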

View File

@@ -1,4 +1,5 @@
 import os
+import re
 import pickle
 import platform
 from pathlib import Path
@@ -128,3 +129,58 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
     model_last = EvalModelTemplate.load_from_checkpoint(path_last)
     for w0, w1 in zip(model_last_epoch.parameters(), model_last.parameters()):
         assert w0.eq(w1).all()
+
+
+def test_ckpt_metric_names(tmpdir):
+    model = EvalModelTemplate()
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        gradient_clip_val=1.0,
+        overfit_batches=0.20,
+        progress_bar_refresh_rate=0,
+        limit_train_batches=0.01,
+        limit_val_batches=0.01,
+        checkpoint_callback=ModelCheckpoint(filepath=tmpdir + '/{val_loss:.2f}')
+    )
+
+    trainer.fit(model)
+
+    # make sure the checkpoint we saved has the metric in the name
+    ckpts = os.listdir(tmpdir)
+    ckpts = [x for x in ckpts if 'val_loss' in x]
+    assert len(ckpts) == 1
+    val = re.sub('[^0-9.]', '', ckpts[0])
+    assert len(val) > 3
+
+
+def test_ckpt_metric_names_results(tmpdir):
+    model = EvalModelTemplate()
+    model.training_step = model.training_step_result_obj
+    model.training_step_end = None
+    model.training_epoch_end = None
+
+    model.validation_step = model.validation_step_result_obj
+    model.validation_step_end = None
+    model.validation_epoch_end = None
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        gradient_clip_val=1.0,
+        overfit_batches=0.20,
+        progress_bar_refresh_rate=0,
+        limit_train_batches=0.01,
+        limit_val_batches=0.01,
+        checkpoint_callback=ModelCheckpoint(filepath=tmpdir + '/{val_loss:.2f}')
+    )
+
+    trainer.fit(model)
+
+    # make sure the checkpoint we saved has the metric in the name
+    ckpts = os.listdir(tmpdir)
+    ckpts = [x for x in ckpts if 'val_loss' in x]
+    assert len(ckpts) == 1
+    val = re.sub('[^0-9.]', '', ckpts[0])
+    assert len(val) > 3
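
For reference, the final assertion in each test worked through on an assumed filename produced by the '{val_loss:.2f}' template above:

import re

ckpt_name = 'val_loss=0.45.ckpt'         # assumed shape of the saved checkpoint's name
val = re.sub('[^0-9.]', '', ckpt_name)   # keep only digits and dots -> '0.45.'
assert len(val) > 3                      # passes only when the formatted metric value made it into the name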

View File

@@ -136,7 +136,7 @@ def test_validation_step_dict_return(tmpdir):
         assert k in eval_results[1]

     # ensure all the keys ended up as candidates for callbacks
-    assert len(trainer.callback_metrics) == 8
+    assert len(trainer.callback_metrics) == 7

     # make sure correct steps were called
     assert model.validation_step_called
@@ -211,7 +211,7 @@ def test_val_step_step_end(tmpdir):
         assert k in eval_results[1]

     # ensure all the keys ended up as candidates for callbacks
-    assert len(trainer.callback_metrics) == 9
+    assert len(trainer.callback_metrics) == 8

     # make sure correct steps were called
     assert model.validation_step_called
@@ -254,7 +254,7 @@ def test_no_val_step_end(tmpdir):
         assert k in eval_results

     # ensure all the keys ended up as candidates for callbacks
-    assert len(trainer.callback_metrics) == 9
+    assert len(trainer.callback_metrics) == 8

     # make sure correct steps were called
     assert model.validation_step_called
@@ -297,7 +297,7 @@ def test_full_val_loop(tmpdir):
         assert k in eval_results

     # ensure all the keys ended up as candidates for callbacks
-    assert len(trainer.callback_metrics) == 10
+    assert len(trainer.callback_metrics) == 9

     # make sure correct steps were called
     assert model.validation_step_called

View File

@@ -108,7 +108,7 @@ def test_full_training_loop_scalar(tmpdir):
     assert model.training_epoch_end_called

     # assert epoch end metrics were added
-    assert 'epoch' in trainer.callback_metrics and len(trainer.callback_metrics) == 1
+    assert len(trainer.callback_metrics) == 0
     assert len(trainer.progress_bar_metrics) == 0

     # make sure training outputs what is expected
@@ -151,7 +151,7 @@ def test_train_step_epoch_end_scalar(tmpdir):
     assert model.training_epoch_end_called

     # assert epoch end metrics were added
-    assert 'epoch' in trainer.callback_metrics and len(trainer.callback_metrics) == 1
+    assert len(trainer.callback_metrics) == 0
     assert len(trainer.progress_bar_metrics) == 0

     # make sure training outputs what is expected