mirror of https://github.com/explosion/spaCy.git
The Parser is now a Pipe (2) (#5844)
* moving syntax folder to _parser_internals * moving nn_parser and transition_system * move nn_parser and transition_system out of internals folder * moving nn_parser code into transition_system file * rename transition_system to transition_parser * moving parser_model and _state to ml * move _state back to internals * The Parser now inherits from Pipe! * small code fixes * removing unnecessary imports * remove link_vectors_to_models * transition_system to internals folder * little bit more cleanup * newlines
This commit is contained in:
parent
3449c45fd9
commit
ca491722ad
|
@ -16,7 +16,7 @@ from bin.ud import conll17_ud_eval
|
|||
from spacy.tokens import Token, Doc
|
||||
from spacy.gold import Example
|
||||
from spacy.util import compounding, minibatch, minibatch_by_words
|
||||
from spacy.syntax.nonproj import projectivize
|
||||
from spacy.pipeline._parser_internals.nonproj import projectivize
|
||||
from spacy.matcher import Matcher
|
||||
from spacy import displacy
|
||||
from collections import defaultdict
|
||||
|
|
|
@ -13,7 +13,7 @@ import spacy
|
|||
import spacy.util
|
||||
from spacy.tokens import Token, Doc
|
||||
from spacy.gold import Example
|
||||
from spacy.syntax.nonproj import projectivize
|
||||
from spacy.pipeline._parser_internals.nonproj import projectivize
|
||||
from collections import defaultdict
|
||||
from spacy.matcher import Matcher
|
||||
|
||||
|
|
16
setup.py
16
setup.py
|
@ -31,6 +31,7 @@ MOD_NAMES = [
|
|||
"spacy.vocab",
|
||||
"spacy.attrs",
|
||||
"spacy.kb",
|
||||
"spacy.ml.parser_model",
|
||||
"spacy.morphology",
|
||||
"spacy.pipeline.dep_parser",
|
||||
"spacy.pipeline.morphologizer",
|
||||
|
@ -40,14 +41,14 @@ MOD_NAMES = [
|
|||
"spacy.pipeline.sentencizer",
|
||||
"spacy.pipeline.senter",
|
||||
"spacy.pipeline.tagger",
|
||||
"spacy.syntax.stateclass",
|
||||
"spacy.syntax._state",
|
||||
"spacy.pipeline.transition_parser",
|
||||
"spacy.pipeline._parser_internals.arc_eager",
|
||||
"spacy.pipeline._parser_internals.ner",
|
||||
"spacy.pipeline._parser_internals.nonproj",
|
||||
"spacy.pipeline._parser_internals._state",
|
||||
"spacy.pipeline._parser_internals.stateclass",
|
||||
"spacy.pipeline._parser_internals.transition_system",
|
||||
"spacy.tokenizer",
|
||||
"spacy.syntax.nn_parser",
|
||||
"spacy.syntax._parser_model",
|
||||
"spacy.syntax.nonproj",
|
||||
"spacy.syntax.transition_system",
|
||||
"spacy.syntax.arc_eager",
|
||||
"spacy.gold.gold_io",
|
||||
"spacy.tokens.doc",
|
||||
"spacy.tokens.span",
|
||||
|
@ -57,7 +58,6 @@ MOD_NAMES = [
|
|||
"spacy.matcher.matcher",
|
||||
"spacy.matcher.phrasematcher",
|
||||
"spacy.matcher.dependencymatcher",
|
||||
"spacy.syntax.ner",
|
||||
"spacy.symbols",
|
||||
"spacy.vectors",
|
||||
]
|
||||
|
|
|
@ -10,7 +10,7 @@ from thinc.api import Config
|
|||
from ._util import app, Arg, Opt, show_validation_error, parse_config_overrides
|
||||
from ._util import import_code, debug_cli
|
||||
from ..gold import Corpus, Example
|
||||
from ..syntax import nonproj
|
||||
from ..pipeline._parser_internals import nonproj
|
||||
from ..language import Language
|
||||
from .. import util
|
||||
|
||||
|
|
|
@ -10,7 +10,7 @@ from .align import Alignment
|
|||
from .iob_utils import biluo_to_iob, biluo_tags_from_offsets, biluo_tags_from_doc
|
||||
from .iob_utils import spans_from_biluo_tags
|
||||
from ..errors import Errors, Warnings
|
||||
from ..syntax import nonproj
|
||||
from ..pipeline._parser_internals import nonproj
|
||||
|
||||
|
||||
cpdef Doc annotations2doc(vocab, tok_annot, doc_annot):
|
||||
|
|
|
@ -1,8 +1,6 @@
|
|||
from libc.string cimport memset, memcpy
|
||||
from libc.stdlib cimport calloc, free, realloc
|
||||
from ..typedefs cimport weight_t, class_t, hash_t
|
||||
|
||||
from ._state cimport StateC
|
||||
from ..typedefs cimport weight_t, hash_t
|
||||
from ..pipeline._parser_internals._state cimport StateC
|
||||
|
||||
|
||||
cdef struct SizesC:
|
|
@ -1,29 +1,18 @@
|
|||
# cython: infer_types=True, cdivision=True, boundscheck=False
|
||||
cimport cython.parallel
|
||||
cimport numpy as np
|
||||
from libc.math cimport exp
|
||||
from libcpp.vector cimport vector
|
||||
from libc.string cimport memset, memcpy
|
||||
from libc.stdlib cimport calloc, free, realloc
|
||||
from cymem.cymem cimport Pool
|
||||
from thinc.extra.search cimport Beam
|
||||
from thinc.backends.linalg cimport Vec, VecVec
|
||||
cimport blis.cy
|
||||
|
||||
import numpy
|
||||
import numpy.random
|
||||
from thinc.api import Linear, Model, CupyOps, NumpyOps, use_ops, noop
|
||||
from thinc.api import Model, CupyOps, NumpyOps
|
||||
|
||||
from ..typedefs cimport weight_t, class_t, hash_t
|
||||
from ..tokens.doc cimport Doc
|
||||
from .stateclass cimport StateClass
|
||||
from .transition_system cimport Transition
|
||||
|
||||
from ..compat import copy_array
|
||||
from ..errors import Errors, TempErrors
|
||||
from ..util import create_default_optimizer
|
||||
from .. import util
|
||||
from . import nonproj
|
||||
from ..typedefs cimport weight_t, class_t, hash_t
|
||||
from ..pipeline._parser_internals.stateclass cimport StateClass
|
||||
|
||||
|
||||
cdef WeightsC get_c_weights(model) except *:
|
|
@ -1,5 +1,5 @@
|
|||
from thinc.api import Model, noop, use_ops, Linear
|
||||
from ..syntax._parser_model import ParserStepModel
|
||||
from .parser_model import ParserStepModel
|
||||
|
||||
|
||||
def TransitionModel(tok2vec, lower, upper, dropout=0.2, unseen_classes=set()):
|
||||
|
|
|
@ -1,15 +1,14 @@
|
|||
from libc.string cimport memcpy, memset, memmove
|
||||
from libc.stdlib cimport malloc, calloc, free
|
||||
from libc.string cimport memcpy, memset
|
||||
from libc.stdlib cimport calloc, free
|
||||
from libc.stdint cimport uint32_t, uint64_t
|
||||
from cpython.exc cimport PyErr_CheckSignals, PyErr_SetFromErrno
|
||||
from murmurhash.mrmr cimport hash64
|
||||
|
||||
from ..vocab cimport EMPTY_LEXEME
|
||||
from ..structs cimport TokenC, SpanC
|
||||
from ..lexeme cimport Lexeme
|
||||
from ..symbols cimport punct
|
||||
from ..attrs cimport IS_SPACE
|
||||
from ..typedefs cimport attr_t
|
||||
from ...vocab cimport EMPTY_LEXEME
|
||||
from ...structs cimport TokenC, SpanC
|
||||
from ...lexeme cimport Lexeme
|
||||
from ...attrs cimport IS_SPACE
|
||||
from ...typedefs cimport attr_t
|
||||
|
||||
|
||||
cdef inline bint is_space_token(const TokenC* token) nogil:
|
|
@ -1,8 +1,6 @@
|
|||
from cymem.cymem cimport Pool
|
||||
|
||||
from .stateclass cimport StateClass
|
||||
from ..typedefs cimport weight_t, attr_t
|
||||
from .transition_system cimport TransitionSystem, Transition
|
||||
from ...typedefs cimport weight_t, attr_t
|
||||
from .transition_system cimport Transition, TransitionSystem
|
||||
|
||||
|
||||
cdef class ArcEager(TransitionSystem):
|
|
@ -1,24 +1,17 @@
|
|||
# cython: profile=True, cdivision=True, infer_types=True
|
||||
from cpython.ref cimport Py_INCREF
|
||||
from cymem.cymem cimport Pool, Address
|
||||
from libc.stdint cimport int32_t
|
||||
|
||||
from collections import defaultdict, Counter
|
||||
import json
|
||||
|
||||
from ..typedefs cimport hash_t, attr_t
|
||||
from ..strings cimport hash_string
|
||||
from ..structs cimport TokenC
|
||||
from ..tokens.doc cimport Doc, set_children_from_heads
|
||||
from ...typedefs cimport hash_t, attr_t
|
||||
from ...strings cimport hash_string
|
||||
from ...structs cimport TokenC
|
||||
from ...tokens.doc cimport Doc, set_children_from_heads
|
||||
from ...gold.example cimport Example
|
||||
from ...errors import Errors
|
||||
from .stateclass cimport StateClass
|
||||
from ._state cimport StateC
|
||||
from .transition_system cimport move_cost_func_t, label_cost_func_t
|
||||
from ..gold.example cimport Example
|
||||
|
||||
from ..errors import Errors
|
||||
from .nonproj import is_nonproj_tree
|
||||
from . import nonproj
|
||||
|
||||
|
||||
# Calculate cost as gold/not gold. We don't use scalar value anyway.
|
||||
cdef int BINARY_COSTS = 1
|
|
@ -1,6 +1,4 @@
|
|||
from .transition_system cimport TransitionSystem
|
||||
from .transition_system cimport Transition
|
||||
from ..typedefs cimport attr_t
|
||||
|
||||
|
||||
cdef class BiluoPushDown(TransitionSystem):
|
|
@ -2,17 +2,14 @@ from collections import Counter
|
|||
from libc.stdint cimport int32_t
|
||||
from cymem.cymem cimport Pool
|
||||
|
||||
from ..typedefs cimport weight_t
|
||||
from ...typedefs cimport weight_t, attr_t
|
||||
from ...lexeme cimport Lexeme
|
||||
from ...attrs cimport IS_SPACE
|
||||
from ...gold.example cimport Example
|
||||
from ...errors import Errors
|
||||
from .stateclass cimport StateClass
|
||||
from ._state cimport StateC
|
||||
from .transition_system cimport Transition
|
||||
from .transition_system cimport do_func_t
|
||||
from ..lexeme cimport Lexeme
|
||||
from ..attrs cimport IS_SPACE
|
||||
from ..gold.iob_utils import biluo_tags_from_offsets
|
||||
from ..gold.example cimport Example
|
||||
|
||||
from ..errors import Errors
|
||||
from .transition_system cimport Transition, do_func_t
|
||||
|
||||
|
||||
cdef enum:
|
|
@ -5,9 +5,9 @@ scheme.
|
|||
"""
|
||||
from copy import copy
|
||||
|
||||
from ..tokens.doc cimport Doc, set_children_from_heads
|
||||
from ...tokens.doc cimport Doc, set_children_from_heads
|
||||
|
||||
from ..errors import Errors
|
||||
from ...errors import Errors
|
||||
|
||||
|
||||
DELIMITER = '||'
|
|
@ -1,12 +1,8 @@
|
|||
from libc.string cimport memcpy, memset
|
||||
|
||||
from cymem.cymem cimport Pool
|
||||
cimport cython
|
||||
|
||||
from ..structs cimport TokenC, SpanC
|
||||
from ..typedefs cimport attr_t
|
||||
from ...structs cimport TokenC, SpanC
|
||||
from ...typedefs cimport attr_t
|
||||
|
||||
from ..vocab cimport EMPTY_LEXEME
|
||||
from ._state cimport StateC
|
||||
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
# cython: infer_types=True
|
||||
import numpy
|
||||
|
||||
from ..tokens.doc cimport Doc
|
||||
from ...tokens.doc cimport Doc
|
||||
|
||||
|
||||
cdef class StateClass:
|
|
@ -1,11 +1,11 @@
|
|||
from cymem.cymem cimport Pool
|
||||
|
||||
from ..typedefs cimport attr_t, weight_t
|
||||
from ..structs cimport TokenC
|
||||
from ..strings cimport StringStore
|
||||
from ...typedefs cimport attr_t, weight_t
|
||||
from ...structs cimport TokenC
|
||||
from ...strings cimport StringStore
|
||||
from ...gold.example cimport Example
|
||||
from .stateclass cimport StateClass
|
||||
from ._state cimport StateC
|
||||
from ..gold.example cimport Example
|
||||
|
||||
|
||||
cdef struct Transition:
|
|
@ -1,19 +1,17 @@
|
|||
# cython: infer_types=True
|
||||
from __future__ import print_function
|
||||
from cpython.ref cimport Py_INCREF
|
||||
from cymem.cymem cimport Pool
|
||||
|
||||
from collections import Counter
|
||||
import srsly
|
||||
|
||||
from ..typedefs cimport weight_t
|
||||
from ..tokens.doc cimport Doc
|
||||
from ..structs cimport TokenC
|
||||
from ...typedefs cimport weight_t, attr_t
|
||||
from ...tokens.doc cimport Doc
|
||||
from ...structs cimport TokenC
|
||||
from .stateclass cimport StateClass
|
||||
from ..typedefs cimport attr_t
|
||||
|
||||
from ..errors import Errors
|
||||
from .. import util
|
||||
from ...errors import Errors
|
||||
from ... import util
|
||||
|
||||
|
||||
cdef weight_t MIN_SCORE = -90000
|
|
@ -1,13 +1,13 @@
|
|||
# cython: infer_types=True, profile=True, binding=True
|
||||
from typing import Optional, Iterable
|
||||
from thinc.api import CosineDistance, to_categorical, get_array_module, Model, Config
|
||||
from thinc.api import Model, Config
|
||||
|
||||
from ..syntax.nn_parser cimport Parser
|
||||
from ..syntax.arc_eager cimport ArcEager
|
||||
from .transition_parser cimport Parser
|
||||
from ._parser_internals.arc_eager cimport ArcEager
|
||||
|
||||
from .functions import merge_subtokens
|
||||
from ..language import Language
|
||||
from ..syntax import nonproj
|
||||
from ._parser_internals import nonproj
|
||||
from ..scorer import Scorer
|
||||
|
||||
|
||||
|
|
|
@ -222,9 +222,9 @@ class EntityLinker(Pipe):
|
|||
set_dropout_rate(self.model, drop)
|
||||
if not sentence_docs:
|
||||
warnings.warn(Warnings.W093.format(name="Entity Linker"))
|
||||
return 0.0
|
||||
return losses
|
||||
sentence_encodings, bp_context = self.model.begin_update(sentence_docs)
|
||||
loss, d_scores = self.get_similarity_loss(
|
||||
loss, d_scores = self.get_loss(
|
||||
sentence_encodings=sentence_encodings, examples=examples
|
||||
)
|
||||
bp_context(d_scores)
|
||||
|
@ -235,7 +235,7 @@ class EntityLinker(Pipe):
|
|||
self.set_annotations(docs, predictions)
|
||||
return losses
|
||||
|
||||
def get_similarity_loss(self, examples: Iterable[Example], sentence_encodings):
|
||||
def get_loss(self, examples: Iterable[Example], sentence_encodings):
|
||||
entity_encodings = []
|
||||
for eg in examples:
|
||||
kb_ids = eg.get_aligned("ENT_KB_ID", as_string=True)
|
||||
|
@ -247,7 +247,7 @@ class EntityLinker(Pipe):
|
|||
entity_encodings = self.model.ops.asarray(entity_encodings, dtype="float32")
|
||||
if sentence_encodings.shape != entity_encodings.shape:
|
||||
err = Errors.E147.format(
|
||||
method="get_similarity_loss", msg="gold entities do not match up"
|
||||
method="get_loss", msg="gold entities do not match up"
|
||||
)
|
||||
raise RuntimeError(err)
|
||||
gradients = self.distance.get_grad(sentence_encodings, entity_encodings)
|
||||
|
@ -337,13 +337,13 @@ class EntityLinker(Pipe):
|
|||
final_kb_ids.append(candidates[0].entity_)
|
||||
else:
|
||||
random.shuffle(candidates)
|
||||
# this will set all prior probabilities to 0 if they should be excluded from the model
|
||||
# set all prior probabilities to 0 if incl_prior=False
|
||||
prior_probs = xp.asarray(
|
||||
[c.prior_prob for c in candidates]
|
||||
)
|
||||
if not self.cfg.get("incl_prior"):
|
||||
prior_probs = xp.asarray(
|
||||
[0.0 for c in candidates]
|
||||
[0.0 for _ in candidates]
|
||||
)
|
||||
scores = prior_probs
|
||||
# add in similarity from the context
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
# cython: infer_types=True, profile=True, binding=True
|
||||
from typing import Optional
|
||||
import numpy
|
||||
from thinc.api import CosineDistance, to_categorical, to_categorical, Model, Config
|
||||
from thinc.api import CosineDistance, to_categorical, Model, Config
|
||||
from thinc.api import set_dropout_rate
|
||||
|
||||
from ..tokens.doc cimport Doc
|
||||
|
@ -9,7 +9,7 @@ from ..tokens.doc cimport Doc
|
|||
from .pipe import Pipe
|
||||
from .tagger import Tagger
|
||||
from ..language import Language
|
||||
from ..syntax import nonproj
|
||||
from ._parser_internals import nonproj
|
||||
from ..attrs import POS, ID
|
||||
from ..errors import Errors
|
||||
|
||||
|
@ -219,3 +219,6 @@ class ClozeMultitask(Pipe):
|
|||
|
||||
if losses is not None:
|
||||
losses[self.name] += loss
|
||||
|
||||
def add_label(self, label):
|
||||
raise NotImplementedError
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
# cython: infer_types=True, profile=True, binding=True
|
||||
from typing import Optional, Iterable
|
||||
from thinc.api import CosineDistance, to_categorical, get_array_module, Model, Config
|
||||
from thinc.api import Model, Config
|
||||
|
||||
from ..syntax.nn_parser cimport Parser
|
||||
from ..syntax.ner cimport BiluoPushDown
|
||||
from .transition_parser cimport Parser
|
||||
from ._parser_internals.ner cimport BiluoPushDown
|
||||
|
||||
from ..language import Language
|
||||
from ..scorer import Scorer
|
||||
|
|
|
@ -0,0 +1,2 @@
|
|||
cdef class Pipe:
|
||||
cdef public str name
|
|
@ -8,7 +8,7 @@ from ..errors import Errors
|
|||
from .. import util
|
||||
|
||||
|
||||
class Pipe:
|
||||
cdef class Pipe:
|
||||
"""This class is a base class and not instantiated directly. Trainable
|
||||
pipeline components like the EntityRecognizer or TextCategorizer inherit
|
||||
from it and it defines the interface that components should follow to
|
||||
|
@ -17,8 +17,6 @@ class Pipe:
|
|||
DOCS: https://spacy.io/api/pipe
|
||||
"""
|
||||
|
||||
name = None
|
||||
|
||||
def __init__(self, vocab, model, name, **cfg):
|
||||
"""Initialize a pipeline component.
|
||||
|
||||
|
|
|
@ -203,3 +203,9 @@ class Sentencizer(Pipe):
|
|||
cfg = srsly.read_json(path)
|
||||
self.punct_chars = set(cfg.get("punct_chars", self.default_punct_chars))
|
||||
return self
|
||||
|
||||
def get_loss(self, examples, scores):
|
||||
raise NotImplementedError
|
||||
|
||||
def add_label(self, label):
|
||||
raise NotImplementedError
|
||||
|
|
|
@ -109,7 +109,7 @@ class SentenceRecognizer(Tagger):
|
|||
for eg in examples:
|
||||
eg_truth = []
|
||||
for x in eg.get_aligned("sent_start"):
|
||||
if x == None:
|
||||
if x is None:
|
||||
eg_truth.append(None)
|
||||
elif x == 1:
|
||||
eg_truth.append(labels[1])
|
||||
|
|
|
@ -131,8 +131,6 @@ class SimpleNER(Pipe):
|
|||
return losses
|
||||
|
||||
def get_loss(self, examples: List[Example], scores) -> Tuple[List[Floats2d], float]:
|
||||
loss = 0
|
||||
d_scores = []
|
||||
truths = []
|
||||
for eg in examples:
|
||||
tags = eg.get_aligned("TAG", as_string=True)
|
||||
|
@ -159,7 +157,6 @@ class SimpleNER(Pipe):
|
|||
if not hasattr(get_examples, "__call__"):
|
||||
gold_tuples = get_examples
|
||||
get_examples = lambda: gold_tuples
|
||||
labels = _get_labels(get_examples())
|
||||
for label in _get_labels(get_examples()):
|
||||
self.add_label(label)
|
||||
labels = self.labels
|
||||
|
|
|
@ -238,8 +238,11 @@ class TextCategorizer(Pipe):
|
|||
|
||||
DOCS: https://spacy.io/api/textcategorizer#rehearse
|
||||
"""
|
||||
|
||||
if losses is not None:
|
||||
losses.setdefault(self.name, 0.0)
|
||||
if self._rehearsal_model is None:
|
||||
return
|
||||
return losses
|
||||
try:
|
||||
docs = [eg.predicted for eg in examples]
|
||||
except AttributeError:
|
||||
|
@ -250,7 +253,7 @@ class TextCategorizer(Pipe):
|
|||
raise TypeError(err)
|
||||
if not any(len(doc) for doc in docs):
|
||||
# Handle cases where there are no tokens in any docs.
|
||||
return
|
||||
return losses
|
||||
set_dropout_rate(self.model, drop)
|
||||
scores, bp_scores = self.model.begin_update(docs)
|
||||
target = self._rehearsal_model(examples)
|
||||
|
@ -259,7 +262,6 @@ class TextCategorizer(Pipe):
|
|||
if sgd is not None:
|
||||
self.model.finish_update(sgd)
|
||||
if losses is not None:
|
||||
losses.setdefault(self.name, 0.0)
|
||||
losses[self.name] += (gradient ** 2).sum()
|
||||
return losses
|
||||
|
||||
|
|
|
@ -199,6 +199,9 @@ class Tok2Vec(Pipe):
|
|||
docs = [Doc(self.vocab, words=["hello"])]
|
||||
self.model.initialize(X=docs)
|
||||
|
||||
def add_label(self, label):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class Tok2VecListener(Model):
|
||||
"""A layer that gets fed its answers from an upstream connection,
|
||||
|
|
|
@ -1,16 +1,15 @@
|
|||
from .stateclass cimport StateClass
|
||||
from .arc_eager cimport TransitionSystem
|
||||
from cymem.cymem cimport Pool
|
||||
|
||||
from ..vocab cimport Vocab
|
||||
from ..tokens.doc cimport Doc
|
||||
from ..structs cimport TokenC
|
||||
from ._state cimport StateC
|
||||
from ._parser_model cimport WeightsC, ActivationsC, SizesC
|
||||
from .pipe cimport Pipe
|
||||
from ._parser_internals.transition_system cimport Transition, TransitionSystem
|
||||
from ._parser_internals._state cimport StateC
|
||||
from ..ml.parser_model cimport WeightsC, ActivationsC, SizesC
|
||||
|
||||
|
||||
cdef class Parser:
|
||||
cdef class Parser(Pipe):
|
||||
cdef readonly Vocab vocab
|
||||
cdef public object model
|
||||
cdef public str name
|
||||
cdef public object _rehearsal_model
|
||||
cdef readonly TransitionSystem moves
|
||||
cdef readonly object cfg
|
|
@ -1,42 +1,32 @@
|
|||
# cython: infer_types=True, cdivision=True, boundscheck=False
|
||||
cimport cython.parallel
|
||||
from __future__ import print_function
|
||||
from cymem.cymem cimport Pool
|
||||
cimport numpy as np
|
||||
from itertools import islice
|
||||
from cpython.ref cimport PyObject, Py_XDECREF
|
||||
from cpython.exc cimport PyErr_CheckSignals, PyErr_SetFromErrno
|
||||
from libc.math cimport exp
|
||||
from libcpp.vector cimport vector
|
||||
from libc.string cimport memset, memcpy
|
||||
from libc.string cimport memset
|
||||
from libc.stdlib cimport calloc, free
|
||||
from cymem.cymem cimport Pool
|
||||
from thinc.backends.linalg cimport Vec, VecVec
|
||||
|
||||
from thinc.api import chain, clone, Linear, list2array, NumpyOps, CupyOps, use_ops
|
||||
from thinc.api import get_array_module, zero_init, set_dropout_rate
|
||||
from itertools import islice
|
||||
import srsly
|
||||
|
||||
from ._parser_internals.stateclass cimport StateClass
|
||||
from ..ml.parser_model cimport alloc_activations, free_activations
|
||||
from ..ml.parser_model cimport predict_states, arg_max_if_valid
|
||||
from ..ml.parser_model cimport WeightsC, ActivationsC, SizesC, cpu_log_loss
|
||||
from ..ml.parser_model cimport get_c_weights, get_c_sizes
|
||||
|
||||
from ..tokens.doc cimport Doc
|
||||
from ..errors import Errors, Warnings
|
||||
from .. import util
|
||||
from ..util import create_default_optimizer
|
||||
|
||||
from thinc.api import set_dropout_rate
|
||||
import numpy.random
|
||||
import numpy
|
||||
import warnings
|
||||
|
||||
from ..tokens.doc cimport Doc
|
||||
from ..typedefs cimport weight_t, class_t, hash_t
|
||||
from ._parser_model cimport alloc_activations, free_activations
|
||||
from ._parser_model cimport predict_states, arg_max_if_valid
|
||||
from ._parser_model cimport WeightsC, ActivationsC, SizesC, cpu_log_loss
|
||||
from ._parser_model cimport get_c_weights, get_c_sizes
|
||||
from .stateclass cimport StateClass
|
||||
from ._state cimport StateC
|
||||
from .transition_system cimport Transition
|
||||
|
||||
from ..util import create_default_optimizer, registry
|
||||
from ..compat import copy_array
|
||||
from ..errors import Errors, Warnings
|
||||
from .. import util
|
||||
from . import nonproj
|
||||
|
||||
|
||||
cdef class Parser:
|
||||
cdef class Parser(Pipe):
|
||||
"""
|
||||
Base class of the DependencyParser and EntityRecognizer.
|
||||
"""
|
||||
|
@ -107,7 +97,7 @@ cdef class Parser:
|
|||
|
||||
@property
|
||||
def tok2vec(self):
|
||||
'''Return the embedding and convolutional layer of the model.'''
|
||||
"""Return the embedding and convolutional layer of the model."""
|
||||
return self.model.get_ref("tok2vec")
|
||||
|
||||
@property
|
||||
|
@ -138,13 +128,13 @@ cdef class Parser:
|
|||
raise NotImplementedError
|
||||
|
||||
def init_multitask_objectives(self, get_examples, pipeline, **cfg):
|
||||
'''Setup models for secondary objectives, to benefit from multi-task
|
||||
"""Setup models for secondary objectives, to benefit from multi-task
|
||||
learning. This method is intended to be overridden by subclasses.
|
||||
|
||||
For instance, the dependency parser can benefit from sharing
|
||||
an input representation with a label prediction model. These auxiliary
|
||||
models are discarded after training.
|
||||
'''
|
||||
"""
|
||||
pass
|
||||
|
||||
def use_params(self, params):
|
|
@ -4,8 +4,8 @@ from spacy import registry
|
|||
from spacy.gold import Example
|
||||
from spacy.pipeline import DependencyParser
|
||||
from spacy.tokens import Doc
|
||||
from spacy.syntax.nonproj import projectivize
|
||||
from spacy.syntax.arc_eager import ArcEager
|
||||
from spacy.pipeline._parser_internals.nonproj import projectivize
|
||||
from spacy.pipeline._parser_internals.arc_eager import ArcEager
|
||||
from spacy.pipeline.dep_parser import DEFAULT_PARSER_MODEL
|
||||
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@ from spacy.lang.en import English
|
|||
|
||||
from spacy.language import Language
|
||||
from spacy.lookups import Lookups
|
||||
from spacy.syntax.ner import BiluoPushDown
|
||||
from spacy.pipeline._parser_internals.ner import BiluoPushDown
|
||||
from spacy.gold import Example
|
||||
from spacy.tokens import Doc
|
||||
from spacy.vocab import Vocab
|
||||
|
|
|
@ -3,8 +3,8 @@ import pytest
|
|||
from spacy import registry
|
||||
from spacy.gold import Example
|
||||
from spacy.vocab import Vocab
|
||||
from spacy.syntax.arc_eager import ArcEager
|
||||
from spacy.syntax.nn_parser import Parser
|
||||
from spacy.pipeline._parser_internals.arc_eager import ArcEager
|
||||
from spacy.pipeline.transition_parser import Parser
|
||||
from spacy.tokens.doc import Doc
|
||||
from thinc.api import Model
|
||||
from spacy.pipeline.tok2vec import DEFAULT_TOK2VEC_MODEL
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import pytest
|
||||
from spacy.syntax.nonproj import ancestors, contains_cycle, is_nonproj_arc
|
||||
from spacy.syntax.nonproj import is_nonproj_tree
|
||||
from spacy.syntax import nonproj
|
||||
from spacy.pipeline._parser_internals.nonproj import ancestors, contains_cycle, is_nonproj_arc
|
||||
from spacy.pipeline._parser_internals.nonproj import is_nonproj_tree
|
||||
from spacy.pipeline._parser_internals import nonproj
|
||||
|
||||
from ..util import get_doc
|
||||
|
||||
|
|
Loading…
Reference in New Issue