From 20b0ec5dcf5b97a3c406ec6bd7aa3f32223c63fa Mon Sep 17 00:00:00 2001 From: svlandeg Date: Wed, 23 Sep 2020 10:37:12 +0200 Subject: [PATCH] avoid logging performance of frozen components --- spacy/cli/train.py | 6 ++++-- spacy/training/loggers.py | 6 ++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/spacy/cli/train.py b/spacy/cli/train.py index bf3749c9e..811a3ba86 100644 --- a/spacy/cli/train.py +++ b/spacy/cli/train.py @@ -152,7 +152,8 @@ def train( exclude=frozen_components, ) msg.info(f"Training. Initial learn rate: {optimizer.learn_rate}") - print_row, finalize_logger = train_logger(nlp) + with nlp.select_pipes(disable=[*frozen_components]): + print_row, finalize_logger = train_logger(nlp) try: progress = tqdm.tqdm(total=T_cfg["eval_frequency"], leave=False) @@ -163,7 +164,8 @@ def train( progress.close() print_row(info) if is_best_checkpoint and output_path is not None: - update_meta(T_cfg, nlp, info) + with nlp.select_pipes(disable=[*frozen_components]): + update_meta(T_cfg, nlp, info) with nlp.use_params(optimizer.averages): nlp.to_disk(output_path / "model-best") progress = tqdm.tqdm(total=T_cfg["eval_frequency"], leave=False) diff --git a/spacy/training/loggers.py b/spacy/training/loggers.py index 92b598033..dddf20169 100644 --- a/spacy/training/loggers.py +++ b/spacy/training/loggers.py @@ -11,9 +11,11 @@ def console_logger(): def setup_printer( nlp: "Language", ) -> Tuple[Callable[[Dict[str, Any]], None], Callable]: + # we assume here that only components are enabled that should be trained & logged + logged_pipes = nlp.pipe_names score_cols = list(nlp.config["training"]["score_weights"]) score_widths = [max(len(col), 6) for col in score_cols] - loss_cols = [f"Loss {pipe}" for pipe in nlp.pipe_names] + loss_cols = [f"Loss {pipe}" for pipe in logged_pipes] loss_widths = [max(len(col), 8) for col in loss_cols] table_header = ["E", "#"] + loss_cols + score_cols + ["Score"] table_header = [col.upper() for col in table_header] @@ -26,7 +28,7 @@ def console_logger(): try: losses = [ "{0:.2f}".format(float(info["losses"][pipe_name])) - for pipe_name in nlp.pipe_names + for pipe_name in logged_pipes ] except KeyError as e: raise KeyError(