2020-10-03 12:57:46 +00:00
|
|
|
from typing import TYPE_CHECKING, Dict, Any, Tuple, Callable, List, Optional, IO
|
2020-10-03 14:31:58 +00:00
|
|
|
from wasabi import Printer
|
2020-10-03 12:57:46 +00:00
|
|
|
import tqdm
|
|
|
|
import sys
|
2020-08-26 13:24:33 +00:00
|
|
|
|
|
|
|
from ..util import registry
|
2020-08-28 11:55:32 +00:00
|
|
|
from .. import util
|
2020-08-26 13:24:33 +00:00
|
|
|
from ..errors import Errors
|
|
|
|
|
2020-10-03 14:31:58 +00:00
|
|
|
if TYPE_CHECKING:
|
|
|
|
from ..language import Language # noqa: F401
|
|
|
|
|
2020-08-26 13:24:33 +00:00
|
|
|
|
2020-10-11 10:55:46 +00:00
|
|
|
def setup_table(
    *, cols: List[str], widths: List[int], max_width: int = 13
) -> Tuple[List[str], List[int], List[str]]:
    """Build the header labels, column widths and alignments for a table.

    Labels longer than ``max_width`` are cut down and suffixed with "...".
    Every label is upper-cased. Each final width is the larger of the
    requested width and the length of the (possibly shortened) label.

    cols (List[str]): The column names.
    widths (List[int]): The minimum width of each column.
    max_width (int): Longest label length allowed before truncation.
    RETURNS (Tuple[List[str], List[int], List[str]]): The header labels, the
        column widths, and one "r" alignment per column.
    """

    def _shorten(label: str) -> str:
        # Keep long component/score names from blowing up the table width.
        if len(label) <= max_width:
            return label
        return label[: max_width - 3] + "..."

    shortened = [_shorten(name) for name, _ in zip(cols, widths)]
    header = [label.upper() for label in shortened]
    col_widths = [
        max(len(label), min_width) for label, min_width in zip(shortened, widths)
    ]
    return header, col_widths, ["r"] * len(col_widths)
|
|
|
|
|
|
|
|
|
2020-08-26 13:24:33 +00:00
|
|
|
@registry.loggers("spacy.ConsoleLogger.v1")
def console_logger(progress_bar: bool = False):
    """Register a logger that prints training progress as a console table.

    progress_bar (bool): If True, show a tqdm progress bar on stderr that
        counts the steps between two evaluation checkpoints.
    RETURNS (Callable): ``setup_printer``, which takes the ``nlp`` object and
        the output streams and returns the ``(log_step, finalize)`` callbacks.
    """

    def setup_printer(
        nlp: "Language", stdout: IO = sys.stdout, stderr: IO = sys.stderr
    ) -> Tuple[Callable[[Optional[Dict[str, Any]]], None], Callable[[], None]]:
        """Print the table header and return the logging callbacks.

        nlp (Language): The pipeline being trained. Its config provides
            ``training.eval_frequency`` and ``training.score_weights``.
        stdout (IO): Stream the table rows are written to.
        stderr (IO): Stream the optional progress bar is written to.
        RETURNS (Tuple[Callable, Callable]): ``log_step`` and ``finalize``.
        """

        def write(text: str) -> None:
            # Flush every row. When stdout is redirected to a file or pipe it
            # is block-buffered, so rows would otherwise show up late and out
            # of order relative to the progress bar on stderr.
            print(text, file=stdout, flush=True)

        msg = Printer(no_print=True)
        # ensure that only trainable components are logged
        logged_pipes = [
            name
            for name, proc in nlp.pipeline
            if hasattr(proc, "is_trainable") and proc.is_trainable
        ]
        eval_frequency = nlp.config["training"]["eval_frequency"]
        score_weights = nlp.config["training"]["score_weights"]
        # Scores whose weight is None are not shown in the table.
        score_cols = [col for col, value in score_weights.items() if value is not None]
        loss_cols = [f"Loss {pipe}" for pipe in logged_pipes]
        spacing = 2
        table_header, table_widths, table_aligns = setup_table(
            cols=["E", "#"] + loss_cols + score_cols + ["Score"],
            widths=[3, 6] + [8 for _ in loss_cols] + [6 for _ in score_cols] + [6],
        )
        write(msg.row(table_header, widths=table_widths, spacing=spacing))
        write(msg.row(["-" * width for width in table_widths], spacing=spacing))
        # tqdm bar for the current evaluation interval; only ever created when
        # progress_bar is True.
        progress = None

        def log_step(info: Optional[Dict[str, Any]]) -> None:
            """Log one training step.

            info (Optional[Dict[str, Any]]): None for a step without a new
                checkpoint (only the progress bar advances). Otherwise a dict
                read for the keys "epoch", "step", "losses", "other_scores"
                and "score", which is printed as one table row.
            RAISES ValueError: If a score in "other_scores" cannot be
                converted to a float.
            """
            nonlocal progress

            if info is None:
                # If we don't have a new checkpoint, just advance the bar.
                if progress is not None:
                    progress.update(1)
                return
            # NOTE(review): assumes info["losses"] has an entry for every
            # trainable pipe; a missing one raises KeyError here.
            losses = [
                "{0:.2f}".format(float(info["losses"][pipe_name]))
                for pipe_name in logged_pipes
            ]

            scores = []
            for col in score_cols:
                score = info["other_scores"].get(col, 0.0)
                try:
                    score = float(score)
                except (TypeError, ValueError):
                    # TypeError for e.g. None or a nested dict, ValueError for a
                    # non-numeric string: both get the same descriptive error.
                    err = Errors.E916.format(name=col, score_type=type(score))
                    raise ValueError(err) from None
                if col != "speed":
                    # Every score except "speed" is scaled by 100 for display.
                    score *= 100
                scores.append("{0:.2f}".format(score))

            data = (
                [info["epoch"], info["step"]]
                + losses
                + scores
                + ["{0:.2f}".format(float(info["score"]))]
            )
            # Close the bar of the interval that just ended before printing
            # the row, so the row is not interleaved with the bar.
            if progress is not None:
                progress.close()
                progress = None
            write(
                msg.row(data, widths=table_widths, aligns=table_aligns, spacing=spacing)
            )
            if progress_bar:
                # Set disable=None, so that it disables on non-TTY
                progress = tqdm.tqdm(
                    total=eval_frequency, disable=None, leave=False, file=stderr
                )
                progress.set_description(f"Epoch {info['epoch']+1}")

        def finalize() -> None:
            """Close a progress bar still open after the last checkpoint."""
            nonlocal progress
            if progress is not None:
                progress.close()
                progress = None

        return log_step, finalize

    return setup_printer
|
|
|
|
|
|
|
|
|
2021-04-01 17:36:23 +00:00
|
|
|
@registry.loggers("spacy.WandbLogger.v2")
def wandb_logger(
    project_name: str,
    remove_config_values: List[str] = [],  # never mutated; kept for config compat
    model_log_interval: Optional[int] = None,
    log_dataset_dir: Optional[str] = None,
):
    """Register a logger that sends training results to Weights & Biases.

    Console output is produced by a wrapped ``console_logger`` (without a
    progress bar).

    project_name (str): The wandb project the run is created in.
    remove_config_values (List[str]): Dot-separated config paths to delete
        from the config before it is sent to wandb.
    model_log_interval (Optional[int]): If set, every N steps (excluding step
        0) the directory at info["output_path"] is logged as a "checkpoint"
        artifact.
    log_dataset_dir (Optional[str]): If set, this directory is logged once
        as a "dataset" artifact when the logger is set up.
    RETURNS (Callable): ``setup_logger``, which returns the
        ``(log_step, finalize)`` callbacks.
    RAISES ImportError: If wandb is not installed or lacks the needed API.
    """
    try:
        import wandb

        # test that these are available
        from wandb import init, log, join  # noqa: F401
    except ImportError:
        raise ImportError(Errors.E880)

    console = console_logger(progress_bar=False)

    def setup_logger(
        nlp: "Language", stdout: IO = sys.stdout, stderr: IO = sys.stderr
    ) -> Tuple[Callable[[Dict[str, Any]], None], Callable[[], None]]:
        """Start the wandb run and return the logging callbacks."""
        config = nlp.config.interpolate()
        config_dot = util.dict_to_dot(config)
        # A path that is not in the config raises KeyError here.
        for field in remove_config_values:
            del config_dot[field]
        config = util.dot_to_dict(config_dot)
        run = wandb.init(project=project_name, config=config, reinit=True)
        console_log_step, console_finalize = console(nlp, stdout, stderr)

        def log_dir_artifact(
            path: str,
            name: str,
            artifact_type: str,
            metadata: Optional[Dict[str, Any]] = None,
            aliases: Optional[List[str]] = None,
        ):
            """Upload the directory ``path`` as a wandb artifact."""
            dataset_artifact = wandb.Artifact(
                name,
                type=artifact_type,
                metadata=metadata if metadata is not None else {},
            )
            dataset_artifact.add_dir(path, name=name)
            wandb.log_artifact(
                dataset_artifact, aliases=aliases if aliases is not None else []
            )

        if log_dataset_dir:
            log_dir_artifact(
                path=log_dataset_dir, name="dataset", artifact_type="dataset"
            )

        def log_step(info: Optional[Dict[str, Any]]):
            """Print the step to the console and send checkpoint data to wandb.

            info (Optional[Dict[str, Any]]): None for steps without a new
                checkpoint; otherwise read for "score", "other_scores",
                "losses", "step", "epoch", "checkpoints" and, optionally,
                "output_path".
            """
            console_log_step(info)
            if info is None:
                return
            score = info["score"]
            other_scores = info["other_scores"]
            losses = info["losses"]
            # Send all metrics of this checkpoint in a single call. Each
            # wandb.log() call advances wandb's step counter by default, so
            # separate calls would put the score, the losses and the other
            # scores of the same checkpoint on different steps.
            metrics: Dict[str, Any] = {"score": score}
            if losses:
                metrics.update({f"loss_{k}": v for k, v in losses.items()})
            if isinstance(other_scores, dict):
                metrics.update(other_scores)
            wandb.log(metrics)

            if model_log_interval and info.get("output_path"):
                step = info["step"]
                if step % model_log_interval == 0 and step != 0:
                    aliases = [f"epoch {info['epoch']} step {step}", "latest"]
                    # NOTE(review): info["checkpoints"] presumably holds
                    # (score, ...) tuples, so max(...)[0] is the best score so
                    # far. Only tag "best" when it applies rather than adding
                    # an empty-string alias.
                    if info["score"] == max(info["checkpoints"])[0]:
                        aliases.append("best")
                    log_dir_artifact(
                        path=info["output_path"],
                        name="pipeline_" + run.id,
                        artifact_type="checkpoint",
                        metadata=info,
                        aliases=aliases,
                    )

        def finalize() -> None:
            """Finalize the console logger, then wait for wandb to finish."""
            console_finalize()
            wandb.join()

        return log_step, finalize

    return setup_logger
|