Tidy up, auto-format, types

This commit is contained in:
Ines Montani 2020-10-03 16:31:58 +02:00
parent 3b8f352eda
commit 989a96308f
1 changed files with 15 additions and 19 deletions

View File

@ -1,5 +1,5 @@
from typing import TYPE_CHECKING, Dict, Any, Tuple, Callable, List, Optional, IO from typing import TYPE_CHECKING, Dict, Any, Tuple, Callable, List, Optional, IO
import wasabi from wasabi import Printer
import tqdm import tqdm
import sys import sys
@ -7,15 +7,16 @@ from ..util import registry
from .. import util from .. import util
from ..errors import Errors from ..errors import Errors
if TYPE_CHECKING:
from ..language import Language # noqa: F401
@registry.loggers("spacy.ConsoleLogger.v1") @registry.loggers("spacy.ConsoleLogger.v1")
def console_logger(progress_bar: bool=False): def console_logger(progress_bar: bool = False):
def setup_printer( def setup_printer(
nlp: "Language", nlp: "Language", stdout: IO = sys.stdout, stderr: IO = sys.stderr
stdout: IO=sys.stdout, ) -> Tuple[Callable[[Optional[Dict[str, Any]]], None], Callable[[], None]]:
stderr: IO=sys.stderr msg = Printer(no_print=True)
) -> Tuple[Callable[[Optional[Dict[str, Any]]], None], Callable]:
msg = wasabi.Printer(no_print=True)
# we assume here that only components are enabled that should be trained & logged # we assume here that only components are enabled that should be trained & logged
logged_pipes = nlp.pipe_names logged_pipes = nlp.pipe_names
eval_frequency = nlp.config["training"]["eval_frequency"] eval_frequency = nlp.config["training"]["eval_frequency"]
@ -32,7 +33,7 @@ def console_logger(progress_bar: bool=False):
stdout.write(msg.row(["-" * width for width in table_widths])) stdout.write(msg.row(["-" * width for width in table_widths]))
progress = None progress = None
def log_step(info: Optional[Dict[str, Any]]): def log_step(info: Optional[Dict[str, Any]]) -> None:
nonlocal progress nonlocal progress
if info is None: if info is None:
@ -78,14 +79,11 @@ def console_logger(progress_bar: bool=False):
if progress_bar: if progress_bar:
# Set disable=None, so that it disables on non-TTY # Set disable=None, so that it disables on non-TTY
progress = tqdm.tqdm( progress = tqdm.tqdm(
total=eval_frequency, total=eval_frequency, disable=None, leave=False, file=stderr
disable=None,
leave=False,
file=stderr
) )
progress.set_description(f"Epoch {info['epoch']+1}") progress.set_description(f"Epoch {info['epoch']+1}")
def finalize(): def finalize() -> None:
pass pass
return log_step, finalize return log_step, finalize
@ -100,10 +98,8 @@ def wandb_logger(project_name: str, remove_config_values: List[str] = []):
console = console_logger(progress_bar=False) console = console_logger(progress_bar=False)
def setup_logger( def setup_logger(
nlp: "Language", nlp: "Language", stdout: IO = sys.stdout, stderr: IO = sys.stderr
stdout: IO=sys.stdout, ) -> Tuple[Callable[[Dict[str, Any]], None], Callable[[], None]]:
stderr: IO=sys.stderr
) -> Tuple[Callable[[Dict[str, Any]], None], Callable]:
config = nlp.config.interpolate() config = nlp.config.interpolate()
config_dot = util.dict_to_dot(config) config_dot = util.dict_to_dot(config)
for field in remove_config_values: for field in remove_config_values:
@ -124,7 +120,7 @@ def wandb_logger(project_name: str, remove_config_values: List[str] = []):
if isinstance(other_scores, dict): if isinstance(other_scores, dict):
wandb.log(other_scores) wandb.log(other_scores)
def finalize(): def finalize() -> None:
console_finalize() console_finalize()
wandb.join() wandb.join()