mirror of https://github.com/explosion/spaCy.git
prevent loading a pretrained Tok2Vec layer AND pretrained components
This commit is contained in:
parent
04ba37b667
commit
291483157d
|
@ -15,6 +15,7 @@ import random
|
|||
|
||||
from .._ml import create_default_optimizer
|
||||
from ..util import use_gpu as set_gpu
|
||||
from ..errors import Errors
|
||||
from ..gold import GoldCorpus
|
||||
from ..compat import path2str
|
||||
from ..lookups import Lookups
|
||||
|
@ -182,6 +183,7 @@ def train(
|
|||
msg.warn("Unable to activate GPU: {}".format(use_gpu))
|
||||
msg.text("Using CPU only")
|
||||
use_gpu = -1
|
||||
base_components = []
|
||||
if base_model:
|
||||
msg.text("Starting with base model '{}'".format(base_model))
|
||||
nlp = util.load_model(base_model)
|
||||
|
@ -227,6 +229,7 @@ def train(
|
|||
exits=1,
|
||||
)
|
||||
msg.text("Extending component from base model '{}'".format(pipe))
|
||||
base_components.append(pipe)
|
||||
disabled_pipes = nlp.disable_pipes(
|
||||
[p for p in nlp.pipe_names if p not in pipeline]
|
||||
)
|
||||
|
@ -299,7 +302,7 @@ def train(
|
|||
|
||||
# Load in pretrained weights
|
||||
if init_tok2vec is not None:
|
||||
components = _load_pretrained_tok2vec(nlp, init_tok2vec)
|
||||
components = _load_pretrained_tok2vec(nlp, init_tok2vec, base_components)
|
||||
msg.text("Loaded pretrained tok2vec for: {}".format(components))
|
||||
|
||||
# Verify textcat config
|
||||
|
@ -642,7 +645,7 @@ def _load_vectors(nlp, vectors):
|
|||
util.load_model(vectors, vocab=nlp.vocab)
|
||||
|
||||
|
||||
def _load_pretrained_tok2vec(nlp, loc):
|
||||
def _load_pretrained_tok2vec(nlp, loc, base_components):
|
||||
"""Load pretrained weights for the 'token-to-vector' part of the component
|
||||
models, which is typically a CNN. See 'spacy pretrain'. Experimental.
|
||||
"""
|
||||
|
@ -651,6 +654,8 @@ def _load_pretrained_tok2vec(nlp, loc):
|
|||
loaded = []
|
||||
for name, component in nlp.pipeline:
|
||||
if hasattr(component, "model") and hasattr(component.model, "tok2vec"):
|
||||
if name in base_components:
|
||||
raise ValueError(Errors.E200.format(component=name))
|
||||
component.tok2vec.from_bytes(weights_data)
|
||||
loaded.append(name)
|
||||
return loaded
|
||||
|
|
|
@ -568,6 +568,8 @@ class Errors(object):
|
|||
E198 = ("Unable to return {n} most similar vectors for the current vectors "
|
||||
"table, which contains {n_rows} vectors.")
|
||||
E199 = ("Unable to merge 0-length span at doc[{start}:{end}].")
|
||||
E200 = ("Specifying a base model with a pretrained component '{component}' "
|
||||
"can not be combined with adding a pretrained Tok2Vec layer.")
|
||||
|
||||
|
||||
@add_codes
|
||||
|
|
Loading…
Reference in New Issue