mirror of https://github.com/explosion/spaCy.git
Tidy up, auto-format, types
This commit is contained in:
parent
3b8f352eda
commit
989a96308f
|
@ -1,5 +1,5 @@
|
|||
from typing import TYPE_CHECKING, Dict, Any, Tuple, Callable, List, Optional, IO
|
||||
import wasabi
|
||||
from wasabi import Printer
|
||||
import tqdm
|
||||
import sys
|
||||
|
||||
|
@ -7,15 +7,16 @@ from ..util import registry
|
|||
from .. import util
|
||||
from ..errors import Errors
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..language import Language # noqa: F401
|
||||
|
||||
|
||||
@registry.loggers("spacy.ConsoleLogger.v1")
|
||||
def console_logger(progress_bar: bool=False):
|
||||
def console_logger(progress_bar: bool = False):
|
||||
def setup_printer(
|
||||
nlp: "Language",
|
||||
stdout: IO=sys.stdout,
|
||||
stderr: IO=sys.stderr
|
||||
) -> Tuple[Callable[[Optional[Dict[str, Any]]], None], Callable]:
|
||||
msg = wasabi.Printer(no_print=True)
|
||||
nlp: "Language", stdout: IO = sys.stdout, stderr: IO = sys.stderr
|
||||
) -> Tuple[Callable[[Optional[Dict[str, Any]]], None], Callable[[], None]]:
|
||||
msg = Printer(no_print=True)
|
||||
# we assume here that only components are enabled that should be trained & logged
|
||||
logged_pipes = nlp.pipe_names
|
||||
eval_frequency = nlp.config["training"]["eval_frequency"]
|
||||
|
@ -32,14 +33,14 @@ def console_logger(progress_bar: bool=False):
|
|||
stdout.write(msg.row(["-" * width for width in table_widths]))
|
||||
progress = None
|
||||
|
||||
def log_step(info: Optional[Dict[str, Any]]):
|
||||
def log_step(info: Optional[Dict[str, Any]]) -> None:
|
||||
nonlocal progress
|
||||
|
||||
if info is None:
|
||||
# If we don't have a new checkpoint, just return.
|
||||
if progress is not None:
|
||||
progress.update(1)
|
||||
return
|
||||
return
|
||||
try:
|
||||
losses = [
|
||||
"{0:.2f}".format(float(info["losses"][pipe_name]))
|
||||
|
@ -78,14 +79,11 @@ def console_logger(progress_bar: bool=False):
|
|||
if progress_bar:
|
||||
# Set disable=None, so that it disables on non-TTY
|
||||
progress = tqdm.tqdm(
|
||||
total=eval_frequency,
|
||||
disable=None,
|
||||
leave=False,
|
||||
file=stderr
|
||||
total=eval_frequency, disable=None, leave=False, file=stderr
|
||||
)
|
||||
progress.set_description(f"Epoch {info['epoch']+1}")
|
||||
|
||||
def finalize():
|
||||
def finalize() -> None:
|
||||
pass
|
||||
|
||||
return log_step, finalize
|
||||
|
@ -100,10 +98,8 @@ def wandb_logger(project_name: str, remove_config_values: List[str] = []):
|
|||
console = console_logger(progress_bar=False)
|
||||
|
||||
def setup_logger(
|
||||
nlp: "Language",
|
||||
stdout: IO=sys.stdout,
|
||||
stderr: IO=sys.stderr
|
||||
) -> Tuple[Callable[[Dict[str, Any]], None], Callable]:
|
||||
nlp: "Language", stdout: IO = sys.stdout, stderr: IO = sys.stderr
|
||||
) -> Tuple[Callable[[Dict[str, Any]], None], Callable[[], None]]:
|
||||
config = nlp.config.interpolate()
|
||||
config_dot = util.dict_to_dot(config)
|
||||
for field in remove_config_values:
|
||||
|
@ -124,7 +120,7 @@ def wandb_logger(project_name: str, remove_config_values: List[str] = []):
|
|||
if isinstance(other_scores, dict):
|
||||
wandb.log(other_scores)
|
||||
|
||||
def finalize():
|
||||
def finalize() -> None:
|
||||
console_finalize()
|
||||
wandb.join()
|
||||
|
||||
|
|
Loading…
Reference in New Issue