mirror of https://github.com/explosion/spaCy.git
Tidy up, auto-format, types
This commit is contained in:
parent
3b8f352eda
commit
989a96308f
|
@ -1,5 +1,5 @@
|
||||||
from typing import TYPE_CHECKING, Dict, Any, Tuple, Callable, List, Optional, IO
|
from typing import TYPE_CHECKING, Dict, Any, Tuple, Callable, List, Optional, IO
|
||||||
import wasabi
|
from wasabi import Printer
|
||||||
import tqdm
|
import tqdm
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
@ -7,15 +7,16 @@ from ..util import registry
|
||||||
from .. import util
|
from .. import util
|
||||||
from ..errors import Errors
|
from ..errors import Errors
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ..language import Language # noqa: F401
|
||||||
|
|
||||||
|
|
||||||
@registry.loggers("spacy.ConsoleLogger.v1")
|
@registry.loggers("spacy.ConsoleLogger.v1")
|
||||||
def console_logger(progress_bar: bool=False):
|
def console_logger(progress_bar: bool = False):
|
||||||
def setup_printer(
|
def setup_printer(
|
||||||
nlp: "Language",
|
nlp: "Language", stdout: IO = sys.stdout, stderr: IO = sys.stderr
|
||||||
stdout: IO=sys.stdout,
|
) -> Tuple[Callable[[Optional[Dict[str, Any]]], None], Callable[[], None]]:
|
||||||
stderr: IO=sys.stderr
|
msg = Printer(no_print=True)
|
||||||
) -> Tuple[Callable[[Optional[Dict[str, Any]]], None], Callable]:
|
|
||||||
msg = wasabi.Printer(no_print=True)
|
|
||||||
# we assume here that only components are enabled that should be trained & logged
|
# we assume here that only components are enabled that should be trained & logged
|
||||||
logged_pipes = nlp.pipe_names
|
logged_pipes = nlp.pipe_names
|
||||||
eval_frequency = nlp.config["training"]["eval_frequency"]
|
eval_frequency = nlp.config["training"]["eval_frequency"]
|
||||||
|
@ -32,7 +33,7 @@ def console_logger(progress_bar: bool=False):
|
||||||
stdout.write(msg.row(["-" * width for width in table_widths]))
|
stdout.write(msg.row(["-" * width for width in table_widths]))
|
||||||
progress = None
|
progress = None
|
||||||
|
|
||||||
def log_step(info: Optional[Dict[str, Any]]):
|
def log_step(info: Optional[Dict[str, Any]]) -> None:
|
||||||
nonlocal progress
|
nonlocal progress
|
||||||
|
|
||||||
if info is None:
|
if info is None:
|
||||||
|
@ -78,14 +79,11 @@ def console_logger(progress_bar: bool=False):
|
||||||
if progress_bar:
|
if progress_bar:
|
||||||
# Set disable=None, so that it disables on non-TTY
|
# Set disable=None, so that it disables on non-TTY
|
||||||
progress = tqdm.tqdm(
|
progress = tqdm.tqdm(
|
||||||
total=eval_frequency,
|
total=eval_frequency, disable=None, leave=False, file=stderr
|
||||||
disable=None,
|
|
||||||
leave=False,
|
|
||||||
file=stderr
|
|
||||||
)
|
)
|
||||||
progress.set_description(f"Epoch {info['epoch']+1}")
|
progress.set_description(f"Epoch {info['epoch']+1}")
|
||||||
|
|
||||||
def finalize():
|
def finalize() -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
return log_step, finalize
|
return log_step, finalize
|
||||||
|
@ -100,10 +98,8 @@ def wandb_logger(project_name: str, remove_config_values: List[str] = []):
|
||||||
console = console_logger(progress_bar=False)
|
console = console_logger(progress_bar=False)
|
||||||
|
|
||||||
def setup_logger(
|
def setup_logger(
|
||||||
nlp: "Language",
|
nlp: "Language", stdout: IO = sys.stdout, stderr: IO = sys.stderr
|
||||||
stdout: IO=sys.stdout,
|
) -> Tuple[Callable[[Dict[str, Any]], None], Callable[[], None]]:
|
||||||
stderr: IO=sys.stderr
|
|
||||||
) -> Tuple[Callable[[Dict[str, Any]], None], Callable]:
|
|
||||||
config = nlp.config.interpolate()
|
config = nlp.config.interpolate()
|
||||||
config_dot = util.dict_to_dot(config)
|
config_dot = util.dict_to_dot(config)
|
||||||
for field in remove_config_values:
|
for field in remove_config_values:
|
||||||
|
@ -124,7 +120,7 @@ def wandb_logger(project_name: str, remove_config_values: List[str] = []):
|
||||||
if isinstance(other_scores, dict):
|
if isinstance(other_scores, dict):
|
||||||
wandb.log(other_scores)
|
wandb.log(other_scores)
|
||||||
|
|
||||||
def finalize():
|
def finalize() -> None:
|
||||||
console_finalize()
|
console_finalize()
|
||||||
wandb.join()
|
wandb.join()
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue