# spaCy/spacy/training/loggers.py
# (96 lines, 3.2 KiB, Python)

from typing import Dict, Any, Tuple, Callable, List
from ..util import registry
# (web-view blame timestamp: 2020-08-28 11:55:32 +00:00)
from .. import util
from ..errors import Errors
from wasabi import msg
@registry.loggers("spacy.ConsoleLogger.v1")
def console_logger():
    """Create a logger setup function that prints training progress as a table.

    RETURNS (Callable): A function that takes the `nlp` object, prints the
        table header and returns a `(log_step, finalize)` tuple of callbacks.
    """

    def setup_printer(
        nlp: "Language",
    ) -> Tuple[Callable[[Dict[str, Any]], None], Callable]:
        """Print the table header and build the per-step callbacks.

        The table has an epoch column, a step column, one loss column per
        entry of `nlp.pipe_names`, one column per key of the training
        config's "score_weights" and a final overall score column.
        """
        # Every numeric cell of the table is rendered with two decimals.
        two_decimals = "{0:.2f}".format

        score_names = list(nlp.config["training"]["score_weights"])
        score_col_widths = [max(len(name), 6) for name in score_names]
        loss_names = [f"Loss {pipe}" for pipe in nlp.pipe_names]
        loss_col_widths = [max(len(name), 8) for name in loss_names]

        header = [
            title.upper()
            for title in ["E", "#"] + loss_names + score_names + ["Score"]
        ]
        widths = [3, 6] + loss_col_widths + score_col_widths + [6]
        # Data rows are right-aligned; the header rows are printed without
        # an explicit alignment.
        aligns = ["r"] * len(widths)

        msg.row(header, widths=widths)
        msg.row(["-" * width for width in widths])

        def log_step(info: Dict[str, Any]):
            """Print one table row for the given training step.

            info (Dict[str, Any]): Must provide "losses", "other_scores",
                "epoch", "step" and "score".
            RAISES (KeyError): With message E983 if a pipe name has no
                entry in info["losses"].
            """
            try:
                loss_cells = []
                # Pipe names are read from nlp on every call, not cached.
                for pipe_name in nlp.pipe_names:
                    loss_value = float(info["losses"][pipe_name])
                    loss_cells.append(two_decimals(loss_value))
            except KeyError as e:
                raise KeyError(
                    Errors.E983.format(
                        dict="scores (losses)",
                        key=str(e),
                        keys=list(info["losses"].keys()),
                    )
                ) from None

            def score_cell(name: str) -> str:
                # Scores missing from "other_scores" are shown as 0.0. All
                # scores except "speed" are scaled by 100 for display.
                value = float(info["other_scores"].get(name, 0.0))
                return two_decimals(value if name == "speed" else value * 100)

            score_cells = [score_cell(name) for name in score_names]

            row = [
                info["epoch"],
                info["step"],
                *loss_cells,
                *score_cells,
                two_decimals(float(info["score"])),
            ]
            msg.row(row, widths=widths, aligns=aligns)

        def finalize():
            """Do nothing: the console logger holds nothing to clean up."""
            return None

        return log_step, finalize

    return setup_printer
@registry.loggers("spacy.WandbLogger.v1")
def wandb_logger(project_name: str, remove_config_values: List[str] = []):
    """Create a logger setup function that logs to the console and to wandb.

    project_name (str): Passed to `wandb.init` as `project`.
    remove_config_values (List[str]): Keys of the flattened config (as
        produced by `util.dict_to_dot`) that are deleted before the config
        is passed to `wandb.init`.
    RETURNS (Callable): A function that takes the `nlp` object, initializes
        the wandb run and returns a `(log_step, finalize)` tuple of callbacks.

    NOTE(review): the default for `remove_config_values` is a shared list
    object; it is only iterated here, never mutated.
    """
    # Imported here, so wandb is only needed once this logger is created.
    import wandb

    # Every step is also echoed to the console table.
    console = console_logger()

    def setup_logger(
        nlp: "Language",
    ) -> Tuple[Callable[[Dict[str, Any]], None], Callable]:
        """Start the wandb run with the filtered config and build callbacks."""
        # Flatten the interpolated config, drop the unwanted keys, then
        # rebuild the nested form before passing it to wandb.
        flat_config = util.dict_to_dot(nlp.config.interpolate())
        for unwanted_key in remove_config_values:
            del flat_config[unwanted_key]
        run_config = util.dot_to_dict(flat_config)
        wandb.init(project=project_name, config=run_config, reinit=True)

        console_log_step, console_finalize = console(nlp)

        def log_step(info: Dict[str, Any]):
            """Print the step to the console, then send its values to wandb."""
            console_log_step(info)
            # Read all three entries before anything is sent to wandb.
            score, other_scores, losses = (
                info["score"],
                info["other_scores"],
                info["losses"],
            )
            wandb.log({"score": score})
            if losses:
                wandb.log({f"loss_{k}": v for k, v in losses.items()})
            if isinstance(other_scores, dict):
                wandb.log(other_scores)

        def finalize():
            """Finalize the console logger, then call `wandb.join()`."""
            console_finalize()
            wandb.join()

        return log_step, finalize

    return setup_logger