diff --git a/spacy/training/loggers.py b/spacy/training/loggers.py index b431ecf06..79459a89b 100644 --- a/spacy/training/loggers.py +++ b/spacy/training/loggers.py @@ -11,11 +11,25 @@ if TYPE_CHECKING: from ..language import Language # noqa: F401 +def setup_table( + *, cols: List[str], widths: List[int], max_width: int = 13 +) -> Tuple[List[str], List[int], List[str]]: + final_cols = [] + final_widths = [] + for col, width in zip(cols, widths): + if len(col) > max_width: + col = col[: max_width - 3] + "..." # shorten column if too long + final_cols.append(col.upper()) + final_widths.append(max(len(col), width)) + return final_cols, final_widths, ["r" for _ in final_widths] + + @registry.loggers("spacy.ConsoleLogger.v1") def console_logger(progress_bar: bool = False): def setup_printer( nlp: "Language", stdout: IO = sys.stdout, stderr: IO = sys.stderr ) -> Tuple[Callable[[Optional[Dict[str, Any]]], None], Callable[[], None]]: + write = lambda text: stdout.write(f"{text}\n") msg = Printer(no_print=True) # ensure that only trainable components are logged logged_pipes = [ @@ -26,15 +40,14 @@ def console_logger(progress_bar: bool = False): eval_frequency = nlp.config["training"]["eval_frequency"] score_weights = nlp.config["training"]["score_weights"] score_cols = [col for col, value in score_weights.items() if value is not None] - score_widths = [max(len(col), 6) for col in score_cols] loss_cols = [f"Loss {pipe}" for pipe in logged_pipes] - loss_widths = [max(len(col), 8) for col in loss_cols] - table_header = ["E", "#"] + loss_cols + score_cols + ["Score"] - table_header = [col.upper() for col in table_header] - table_widths = [3, 6] + loss_widths + score_widths + [6] - table_aligns = ["r" for _ in table_widths] - stdout.write(msg.row(table_header, widths=table_widths) + "\n") - stdout.write(msg.row(["-" * width for width in table_widths]) + "\n") + spacing = 2 + table_header, table_widths, table_aligns = setup_table( + cols=["E", "#"] + loss_cols + score_cols + ["Score"], + widths=[3, 6] + [8 for _ in loss_cols] + [6 for _ in score_cols] + [6], + ) + write(msg.row(table_header, widths=table_widths, spacing=spacing)) + write(msg.row(["-" * width for width in table_widths], spacing=spacing)) progress = None def log_step(info: Optional[Dict[str, Any]]) -> None: @@ -70,7 +83,9 @@ def console_logger(progress_bar: bool = False): ) if progress is not None: progress.close() - stdout.write(msg.row(data, widths=table_widths, aligns=table_aligns) + "\n") + write( + msg.row(data, widths=table_widths, aligns=table_aligns, spacing=spacing) + ) if progress_bar: # Set disable=None, so that it disables on non-TTY progress = tqdm.tqdm(