From bfa8e11ffa58de91393d1ba398c12f7a600b8d48 Mon Sep 17 00:00:00 2001 From: Ines Montani Date: Fri, 10 Jul 2020 20:52:00 +0200 Subject: [PATCH] Update and auto-format --- spacy/cli/debug_model.py | 27 +++++++++++++-------------- spacy/schemas.py | 2 +- 2 files changed, 14 insertions(+), 15 deletions(-) diff --git a/spacy/cli/debug_model.py b/spacy/cli/debug_model.py index 54c71f824..2e315cae5 100644 --- a/spacy/cli/debug_model.py +++ b/spacy/cli/debug_model.py @@ -1,10 +1,9 @@ -from typing import List from pathlib import Path from wasabi import msg - -from ._app import app, Arg, Opt -from .. import util from thinc.api import require_gpu, fix_random_seed, set_dropout_rate, Adam + +from ._util import app, Arg, Opt +from .. import util from ..lang.en import English @@ -50,16 +49,11 @@ def debug_model_cli( msg.info(f"Using CPU") debug_model( - config_path, - print_settings=print_settings, + config_path, print_settings=print_settings, ) -def debug_model( - config_path: Path, - *, - print_settings=None -): +def debug_model(config_path: Path, *, print_settings=None): if print_settings is None: print_settings = {} @@ -83,7 +77,7 @@ def debug_model( for e in range(3): Y, get_dX = model.begin_update(_get_docs()) dY = get_gradient(model, Y) - _ = get_dX(dY) + get_dX(dY) model.finish_update(optimizer) if print_settings.get("print_after_training"): msg.info(f"After training:") @@ -115,7 +109,12 @@ def _get_docs(): def _get_output(xp): - return xp.asarray([xp.asarray([i+10, i+20, i+30], dtype="float32") for i, _ in enumerate(_get_docs())]) + return xp.asarray( + [ + xp.asarray([i + 10, i + 20, i + 30], dtype="float32") + for i, _ in enumerate(_get_docs()) + ] + ) def _print_model(model, print_settings): @@ -161,7 +160,7 @@ def _print_matrix(value): return value result = str(value.shape) + " - sample: " sample_matrix = value - for d in range(value.ndim-1): + for d in range(value.ndim - 1): sample_matrix = sample_matrix[0] sample_matrix = sample_matrix[0:5] result = result + str(sample_matrix) diff --git a/spacy/schemas.py b/spacy/schemas.py index c4d67e90f..e0776c56e 100644 --- a/spacy/schemas.py +++ b/spacy/schemas.py @@ -201,7 +201,7 @@ class ConfigSchemaTraining(BaseModel): max_epochs: StrictInt = Field(..., title="Maximum number of epochs to train for") max_steps: StrictInt = Field(..., title="Maximum number of update steps to train for") eval_frequency: StrictInt = Field(..., title="How often to evaluate during training (steps)") - seed: StrictInt = Field(..., title="Random seed") + seed: Optional[StrictInt] = Field(..., title="Random seed") accumulate_gradient: StrictInt = Field(..., title="Whether to divide the batch up into substeps") use_pytorch_for_gpu_memory: StrictBool = Field(..., title="Allocate memory via PyTorch") use_gpu: StrictInt = Field(..., title="GPU ID or -1 for CPU")