re-enabled naming metrics in ckpt name (#3060)
* re-enabled naming metrics in ckpt name
This commit is contained in:
parent
cefc7f7c32
commit
3453bba898
|
@ -339,10 +339,11 @@ class ModelCheckpoint(Callback):
|
|||
|
||||
self.epoch_last_check = epoch
|
||||
|
||||
filepath = self.format_checkpoint_name(epoch, metrics)
|
||||
ckpt_name_metrics = trainer.logged_metrics
|
||||
filepath = self.format_checkpoint_name(epoch, ckpt_name_metrics)
|
||||
version_cnt = 0
|
||||
while gfile.exists(filepath):
|
||||
filepath = self.format_checkpoint_name(epoch, metrics, ver=version_cnt)
|
||||
filepath = self.format_checkpoint_name(epoch, ckpt_name_metrics, ver=version_cnt)
|
||||
# this epoch called before
|
||||
version_cnt += 1
|
||||
|
||||
|
|
|
@ -24,6 +24,7 @@ class TrainerLoggingMixin(ABC):
|
|||
default_root_dir: str
|
||||
slurm_job_id: int
|
||||
num_gpus: int
|
||||
logged_metrics: ...
|
||||
|
||||
def configure_logger(self, logger):
|
||||
if logger is True:
|
||||
|
@ -75,6 +76,8 @@ class TrainerLoggingMixin(ABC):
|
|||
self.logger.agg_and_log_metrics(scalar_metrics, step=step)
|
||||
self.logger.save()
|
||||
|
||||
# track the logged metrics
|
||||
self.logged_metrics = scalar_metrics
|
||||
self.dev_debugger.track_logged_metrics_history(scalar_metrics)
|
||||
|
||||
def add_progress_bar_metrics(self, metrics):
|
||||
|
|
|
@ -374,6 +374,7 @@ class Trainer(
|
|||
self.batch_idx = 0
|
||||
self.progress_bar_metrics = {}
|
||||
self.callback_metrics = {}
|
||||
self.logged_metrics = {}
|
||||
self.num_training_batches = 0
|
||||
self.num_val_batches = []
|
||||
self.num_test_batches = []
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import os
|
||||
import re
|
||||
import pickle
|
||||
import platform
|
||||
from pathlib import Path
|
||||
|
@ -128,3 +129,58 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
|
|||
model_last = EvalModelTemplate.load_from_checkpoint(path_last)
|
||||
for w0, w1 in zip(model_last_epoch.parameters(), model_last.parameters()):
|
||||
assert w0.eq(w1).all()
|
||||
|
||||
|
||||
def test_ckpt_metric_names(tmpdir):
    """Checkpoint filenames should embed the logged metric and its value."""
    model = EvalModelTemplate()

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        gradient_clip_val=1.0,
        overfit_batches=0.20,
        progress_bar_refresh_rate=0,
        limit_train_batches=0.01,
        limit_val_batches=0.01,
        checkpoint_callback=ModelCheckpoint(filepath=tmpdir + '/{val_loss:.2f}'),
    )
    trainer.fit(model)

    # exactly one checkpoint carrying the metric name must have been written
    metric_ckpts = [name for name in os.listdir(tmpdir) if 'val_loss' in name]
    assert len(metric_ckpts) == 1

    # and the filename must contain a formatted numeric value (e.g. "val_loss=1.23.ckpt")
    digits = re.sub('[^0-9.]', '', metric_ckpts[0])
    assert len(digits) > 3
|
||||
|
||||
|
||||
def test_ckpt_metric_names_results(tmpdir):
    """Same as test_ckpt_metric_names, but with Result-object steps."""
    model = EvalModelTemplate()

    # switch the template model onto its Result-object step implementations
    model.training_step = model.training_step_result_obj
    model.training_step_end = None
    model.training_epoch_end = None

    model.validation_step = model.validation_step_result_obj
    model.validation_step_end = None
    model.validation_epoch_end = None

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        gradient_clip_val=1.0,
        overfit_batches=0.20,
        progress_bar_refresh_rate=0,
        limit_train_batches=0.01,
        limit_val_batches=0.01,
        checkpoint_callback=ModelCheckpoint(filepath=tmpdir + '/{val_loss:.2f}'),
    )
    trainer.fit(model)

    # exactly one checkpoint carrying the metric name must have been written
    metric_ckpts = [name for name in os.listdir(tmpdir) if 'val_loss' in name]
    assert len(metric_ckpts) == 1

    # and the filename must contain a formatted numeric value (e.g. "val_loss=1.23.ckpt")
    digits = re.sub('[^0-9.]', '', metric_ckpts[0])
    assert len(digits) > 3
|
||||
|
|
|
@ -136,7 +136,7 @@ def test_validation_step_dict_return(tmpdir):
|
|||
assert k in eval_results[1]
|
||||
|
||||
# ensure all the keys ended up as candidates for callbacks
|
||||
assert len(trainer.callback_metrics) == 8
|
||||
assert len(trainer.callback_metrics) == 7
|
||||
|
||||
# make sure correct steps were called
|
||||
assert model.validation_step_called
|
||||
|
@ -211,7 +211,7 @@ def test_val_step_step_end(tmpdir):
|
|||
assert k in eval_results[1]
|
||||
|
||||
# ensure all the keys ended up as candidates for callbacks
|
||||
assert len(trainer.callback_metrics) == 9
|
||||
assert len(trainer.callback_metrics) == 8
|
||||
|
||||
# make sure correct steps were called
|
||||
assert model.validation_step_called
|
||||
|
@ -254,7 +254,7 @@ def test_no_val_step_end(tmpdir):
|
|||
assert k in eval_results
|
||||
|
||||
# ensure all the keys ended up as candidates for callbacks
|
||||
assert len(trainer.callback_metrics) == 9
|
||||
assert len(trainer.callback_metrics) == 8
|
||||
|
||||
# make sure correct steps were called
|
||||
assert model.validation_step_called
|
||||
|
@ -297,7 +297,7 @@ def test_full_val_loop(tmpdir):
|
|||
assert k in eval_results
|
||||
|
||||
# ensure all the keys ended up as candidates for callbacks
|
||||
assert len(trainer.callback_metrics) == 10
|
||||
assert len(trainer.callback_metrics) == 9
|
||||
|
||||
# make sure correct steps were called
|
||||
assert model.validation_step_called
|
||||
|
|
|
@ -108,7 +108,7 @@ def test_full_training_loop_scalar(tmpdir):
|
|||
assert model.training_epoch_end_called
|
||||
|
||||
# assert epoch end metrics were added
|
||||
assert 'epoch' in trainer.callback_metrics and len(trainer.callback_metrics) == 1
|
||||
assert len(trainer.callback_metrics) == 0
|
||||
assert len(trainer.progress_bar_metrics) == 0
|
||||
|
||||
# make sure training outputs what is expected
|
||||
|
@ -151,7 +151,7 @@ def test_train_step_epoch_end_scalar(tmpdir):
|
|||
assert model.training_epoch_end_called
|
||||
|
||||
# assert epoch end metrics were added
|
||||
assert 'epoch' in trainer.callback_metrics and len(trainer.callback_metrics) == 1
|
||||
assert len(trainer.callback_metrics) == 0
|
||||
assert len(trainer.progress_bar_metrics) == 0
|
||||
|
||||
# make sure training outputs what is expected
|
||||
|
|
Loading…
Reference in New Issue