mirror of https://github.com/explosion/spaCy.git
Fix add_label methods
This commit is contained in:
parent
dad8f09fba
commit
7ae1aacdb8
|
@ -441,11 +441,12 @@ class Tagger(Pipe):
|
||||||
def add_label(self, label):
|
def add_label(self, label):
|
||||||
if label in self.labels:
|
if label in self.labels:
|
||||||
return 0
|
return 0
|
||||||
smaller = self.model[-1]._layers[-1]
|
if self.model not in (True, False, None):
|
||||||
|
smaller = self.model._layers[-1]
|
||||||
larger = Softmax(len(self.labels)+1, smaller.nI)
|
larger = Softmax(len(self.labels)+1, smaller.nI)
|
||||||
copy_array(larger.W[:smaller.nO], smaller.W)
|
copy_array(larger.W[:smaller.nO], smaller.W)
|
||||||
copy_array(larger.b[:smaller.nO], smaller.b)
|
copy_array(larger.b[:smaller.nO], smaller.b)
|
||||||
self.model[-1]._layers[-1] = larger
|
self.model._layers[-1] = larger
|
||||||
self.labels.append(label)
|
self.labels.append(label)
|
||||||
return 1
|
return 1
|
||||||
|
|
||||||
|
@ -759,11 +760,12 @@ class TextCategorizer(Pipe):
|
||||||
def add_label(self, label):
|
def add_label(self, label):
|
||||||
if label in self.labels:
|
if label in self.labels:
|
||||||
return 0
|
return 0
|
||||||
smaller = self.model[-1]._layers[-1]
|
if self.model not in (None, True, False):
|
||||||
|
smaller = self.model._layers[-1]
|
||||||
larger = Affine(len(self.labels)+1, smaller.nI)
|
larger = Affine(len(self.labels)+1, smaller.nI)
|
||||||
copy_array(larger.W[:smaller.nO], smaller.W)
|
copy_array(larger.W[:smaller.nO], smaller.W)
|
||||||
copy_array(larger.b[:smaller.nO], smaller.b)
|
copy_array(larger.b[:smaller.nO], smaller.b)
|
||||||
self.model[-1]._layers[-1] = larger
|
self.model._layers[-1] = larger
|
||||||
self.labels.append(label)
|
self.labels.append(label)
|
||||||
return 1
|
return 1
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue