mirror of https://github.com/explosion/spaCy.git
use data_validation context manager
This commit is contained in:
parent
5fa3235d06
commit
cc2f58a1b0
|
@ -2,7 +2,7 @@ from typing import Dict, Any, Optional
|
|||
from pathlib import Path
|
||||
from wasabi import msg
|
||||
from thinc.api import require_gpu, fix_random_seed, set_dropout_rate, Adam, Config
|
||||
from thinc.api import Model, DATA_VALIDATION
|
||||
from thinc.api import Model, data_validation
|
||||
import typer
|
||||
|
||||
from ._util import Arg, Opt, debug_cli, show_validation_error, parse_config_overrides
|
||||
|
@ -90,9 +90,9 @@ def debug_model(model: Model, *, print_settings: Optional[Dict[str, Any]] = None
|
|||
# STEP 1: Initializing the model and printing again
|
||||
Y = _get_output(model.ops.xp)
|
||||
_set_output_dim(nO=Y.shape[-1], model=model)
|
||||
DATA_VALIDATION.set(False) # The output vector might differ from the official type of the output layer
|
||||
model.initialize(X=_get_docs(), Y=Y)
|
||||
DATA_VALIDATION.set(True)
|
||||
# The output vector might differ from the official type of the output layer
|
||||
with data_validation(False):
|
||||
model.initialize(X=_get_docs(), Y=Y)
|
||||
if print_settings.get("print_after_init"):
|
||||
msg.info(f"After initialization:")
|
||||
_print_model(model, print_settings)
|
||||
|
|
Loading…
Reference in New Issue