mirror of https://github.com/explosion/spaCy.git
Fix add_label methods
This commit is contained in:
parent
dad8f09fba
commit
7ae1aacdb8
|
@ -441,11 +441,12 @@ class Tagger(Pipe):
|
|||
def add_label(self, label):
|
||||
if label in self.labels:
|
||||
return 0
|
||||
smaller = self.model[-1]._layers[-1]
|
||||
larger = Softmax(len(self.labels)+1, smaller.nI)
|
||||
copy_array(larger.W[:smaller.nO], smaller.W)
|
||||
copy_array(larger.b[:smaller.nO], smaller.b)
|
||||
self.model[-1]._layers[-1] = larger
|
||||
if self.model not in (True, False, None):
|
||||
smaller = self.model._layers[-1]
|
||||
larger = Softmax(len(self.labels)+1, smaller.nI)
|
||||
copy_array(larger.W[:smaller.nO], smaller.W)
|
||||
copy_array(larger.b[:smaller.nO], smaller.b)
|
||||
self.model._layers[-1] = larger
|
||||
self.labels.append(label)
|
||||
return 1
|
||||
|
||||
|
@ -759,11 +760,12 @@ class TextCategorizer(Pipe):
|
|||
def add_label(self, label):
|
||||
if label in self.labels:
|
||||
return 0
|
||||
smaller = self.model[-1]._layers[-1]
|
||||
larger = Affine(len(self.labels)+1, smaller.nI)
|
||||
copy_array(larger.W[:smaller.nO], smaller.W)
|
||||
copy_array(larger.b[:smaller.nO], smaller.b)
|
||||
self.model[-1]._layers[-1] = larger
|
||||
if self.model not in (None, True, False):
|
||||
smaller = self.model._layers[-1]
|
||||
larger = Affine(len(self.labels)+1, smaller.nI)
|
||||
copy_array(larger.W[:smaller.nO], smaller.W)
|
||||
copy_array(larger.b[:smaller.nO], smaller.b)
|
||||
self.model._layers[-1] = larger
|
||||
self.labels.append(label)
|
||||
return 1
|
||||
|
||||
|
|
Loading…
Reference in New Issue