Fix add_label methods

This commit is contained in:
Matthew Honnibal 2017-11-01 17:06:43 +01:00
parent dad8f09fba
commit 7ae1aacdb8
1 changed files with 12 additions and 10 deletions

View File

@ -441,11 +441,12 @@ class Tagger(Pipe):
def add_label(self, label): def add_label(self, label):
if label in self.labels: if label in self.labels:
return 0 return 0
smaller = self.model[-1]._layers[-1] if self.model not in (True, False, None):
smaller = self.model._layers[-1]
larger = Softmax(len(self.labels)+1, smaller.nI) larger = Softmax(len(self.labels)+1, smaller.nI)
copy_array(larger.W[:smaller.nO], smaller.W) copy_array(larger.W[:smaller.nO], smaller.W)
copy_array(larger.b[:smaller.nO], smaller.b) copy_array(larger.b[:smaller.nO], smaller.b)
self.model[-1]._layers[-1] = larger self.model._layers[-1] = larger
self.labels.append(label) self.labels.append(label)
return 1 return 1
@ -759,11 +760,12 @@ class TextCategorizer(Pipe):
def add_label(self, label): def add_label(self, label):
if label in self.labels: if label in self.labels:
return 0 return 0
smaller = self.model[-1]._layers[-1] if self.model not in (None, True, False):
smaller = self.model._layers[-1]
larger = Affine(len(self.labels)+1, smaller.nI) larger = Affine(len(self.labels)+1, smaller.nI)
copy_array(larger.W[:smaller.nO], smaller.W) copy_array(larger.W[:smaller.nO], smaller.W)
copy_array(larger.b[:smaller.nO], smaller.b) copy_array(larger.b[:smaller.nO], smaller.b)
self.model[-1]._layers[-1] = larger self.model._layers[-1] = larger
self.labels.append(label) self.labels.append(label)
return 1 return 1