mirror of https://github.com/explosion/spaCy.git
Pass values for CNN maxout pieces option
This commit is contained in:
parent
b832f89ff8
commit
24e85c2048
|
@ -215,6 +215,7 @@ class TokenVectorEncoder(BaseThincComponent):
|
|||
self.model = model
|
||||
self.cfg = dict(cfg)
|
||||
self.cfg['pretrained_dims'] = self.vocab.vectors.data.shape[1]
|
||||
self.cfg.setdefault('cnn_maxout_pieces', 2)
|
||||
|
||||
def __call__(self, doc):
|
||||
"""Add context-sensitive vectors to a `Doc`, e.g. from a CNN or LSTM
|
||||
|
@ -286,9 +287,7 @@ class TokenVectorEncoder(BaseThincComponent):
|
|||
pipeline (list): The pipeline the model is part of.
|
||||
"""
|
||||
if self.model is True:
|
||||
self.model = self.Model(
|
||||
pretrained_dims=self.vocab.vectors_length,
|
||||
**self.cfg)
|
||||
self.model = self.Model(**self.cfg)
|
||||
|
||||
|
||||
class NeuralTagger(BaseThincComponent):
|
||||
|
@ -297,6 +296,7 @@ class NeuralTagger(BaseThincComponent):
|
|||
self.vocab = vocab
|
||||
self.model = model
|
||||
self.cfg = dict(cfg)
|
||||
self.cfg.setdefault('cnn_maxout_pieces', 2)
|
||||
|
||||
def __call__(self, doc):
|
||||
tags = self.predict(([doc], [doc.tensor]))
|
||||
|
@ -442,6 +442,7 @@ class NeuralTagger(BaseThincComponent):
|
|||
return self
|
||||
|
||||
def to_disk(self, path, **exclude):
|
||||
self.cfg['pretrained_dims'] = self.vocab.vectors.data.shape[1]
|
||||
serialize = OrderedDict((
|
||||
('vocab', lambda p: self.vocab.to_disk(p)),
|
||||
('tag_map', lambda p: p.open('wb').write(msgpack.dumps(
|
||||
|
@ -486,6 +487,7 @@ class NeuralLabeller(NeuralTagger):
|
|||
self.vocab = vocab
|
||||
self.model = model
|
||||
self.cfg = dict(cfg)
|
||||
self.cfg.setdefault('cnn_maxout_pieces', 2)
|
||||
|
||||
@property
|
||||
def labels(self):
|
||||
|
|
|
@ -309,6 +309,7 @@ cdef class Parser:
|
|||
cfg['beam_density'] = util.env_opt('beam_density', 0.0)
|
||||
if 'pretrained_dims' not in cfg:
|
||||
cfg['pretrained_dims'] = self.vocab.vectors.data.shape[1]
|
||||
cfg.setdefault('cnn_maxout_pieces', 2)
|
||||
self.cfg = cfg
|
||||
if 'actions' in self.cfg:
|
||||
for action, labels in self.cfg.get('actions', {}).items():
|
||||
|
|
Loading…
Reference in New Issue