prevent loss keyerror for non-trainable components

This commit is contained in:
svlandeg 2020-10-05 16:33:28 +02:00
parent 65abd77779
commit dc06912c76
2 changed files with 7 additions and 15 deletions

View File

@ -41,19 +41,10 @@ def console_logger(progress_bar: bool = False):
if progress is not None:
progress.update(1)
return
try:
losses = [
"{0:.2f}".format(float(info["losses"][pipe_name]))
for pipe_name in logged_pipes
]
except KeyError as e:
raise KeyError(
Errors.E983.format(
dict="scores (losses)",
key=str(e),
keys=list(info["losses"].keys()),
)
) from None
losses = [
"{0:.2f}".format(float(info["losses"][pipe_name]))
for pipe_name in logged_pipes if pipe_name in info["losses"]
]
scores = []
for col in score_cols:

View File

@ -184,7 +184,7 @@ def train_while_improving(
and hasattr(proc, "model")
and proc.model not in (True, False, None)
):
proc.model.finish_update(optimizer)
proc.finish_update(optimizer)
optimizer.step_schedules()
if not (step % eval_frequency):
if optimizer.averages:
@ -287,7 +287,8 @@ def update_meta(
if metric is not None:
nlp.meta["performance"][metric] = info["other_scores"].get(metric, 0.0)
for pipe_name in nlp.pipe_names:
nlp.meta["performance"][f"{pipe_name}_loss"] = info["losses"][pipe_name]
if pipe_name in info["losses"]:
nlp.meta["performance"][f"{pipe_name}_loss"] = info["losses"][pipe_name]
def create_before_to_disk_callback(