diff --git a/spacy/pipeline.pyx b/spacy/pipeline.pyx index f4a654591..7a41085e4 100644 --- a/spacy/pipeline.pyx +++ b/spacy/pipeline.pyx @@ -866,8 +866,8 @@ class TextCategorizer(Pipe): name = 'textcat' @classmethod - def Model(cls, nr_class=1, width=64, **cfg): - return build_text_classifier(nr_class, width, **cfg) + def Model(cls, **cfg): + return build_text_classifier(**cfg) def __init__(self, vocab, model=True, **cfg): self.vocab = vocab @@ -948,8 +948,9 @@ class TextCategorizer(Pipe): token_vector_width = 64 if self.model is True: self.cfg['pretrained_dims'] = self.vocab.vectors_length - self.model = self.Model(len(self.labels), token_vector_width, - **self.cfg) + self.cfg['nr_class'] = len(self.labels) + self.cfg['width'] = token_vector_width + self.model = self.Model(**self.cfg) link_vectors_to_models(self.vocab) if sgd is None: sgd = self.create_optimizer()