mirror of https://github.com/explosion/spaCy.git
fix convolution layer
This commit is contained in:
parent
dd691d0053
commit
7edb2e1711
|
@ -12,9 +12,9 @@ from examples.pipeline.wiki_entity_linking import run_el, training_set_creator,
|
||||||
|
|
||||||
from spacy._ml import SpacyVectors, create_default_optimizer, zero_init, logistic
|
from spacy._ml import SpacyVectors, create_default_optimizer, zero_init, logistic
|
||||||
|
|
||||||
from thinc.api import chain, concatenate, flatten_add_lengths, clone
|
from thinc.api import chain, concatenate, flatten_add_lengths, clone, with_flatten
|
||||||
from thinc.v2v import Model, Maxout, Affine
|
from thinc.v2v import Model, Maxout, Affine
|
||||||
from thinc.t2v import Pooling, mean_pool
|
from thinc.t2v import Pooling, mean_pool, sum_pool
|
||||||
from thinc.t2t import ParametricAttention
|
from thinc.t2t import ParametricAttention
|
||||||
from thinc.misc import Residual
|
from thinc.misc import Residual
|
||||||
from thinc.misc import LayerNorm as LN
|
from thinc.misc import LayerNorm as LN
|
||||||
|
@ -96,13 +96,13 @@ class EL_Model:
|
||||||
try:
|
try:
|
||||||
# if to_print:
|
# if to_print:
|
||||||
# print()
|
# print()
|
||||||
# print(article_count, "Training on article", article_id)
|
print(article_count, "Training on article", article_id)
|
||||||
article_count += 1
|
article_count += 1
|
||||||
article_docs = list()
|
article_docs = list()
|
||||||
entities = list()
|
entities = list()
|
||||||
golds = list()
|
golds = list()
|
||||||
for inst_cluster in inst_cluster_set:
|
for inst_cluster in inst_cluster_set:
|
||||||
if instance_pos_count < 2: # TODO remove
|
if instance_pos_count < 2: # TODO del
|
||||||
article_docs.append(train_doc[article_id])
|
article_docs.append(train_doc[article_id])
|
||||||
entities.append(train_pos.get(inst_cluster))
|
entities.append(train_pos.get(inst_cluster))
|
||||||
golds.append(float(1.0))
|
golds.append(float(1.0))
|
||||||
|
@ -228,16 +228,23 @@ class EL_Model:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _encoder(in_width, hidden_width):
|
def _encoder(in_width, hidden_width):
|
||||||
|
conv_depth = 1
|
||||||
|
cnn_maxout_pieces = 3
|
||||||
|
|
||||||
with Model.define_operators({">>": chain}):
|
with Model.define_operators({">>": chain}):
|
||||||
|
convolution = Residual((ExtractWindow(nW=1) >> LN(Maxout(in_width, in_width * 3, pieces=cnn_maxout_pieces))))
|
||||||
|
|
||||||
encoder = SpacyVectors \
|
encoder = SpacyVectors \
|
||||||
|
>> with_flatten(LN(Maxout(in_width, in_width)) >> convolution ** conv_depth, pad=conv_depth) \
|
||||||
>> flatten_add_lengths \
|
>> flatten_add_lengths \
|
||||||
>> ParametricAttention(in_width)\
|
>> ParametricAttention(in_width)\
|
||||||
>> Pooling(mean_pool) \
|
>> Pooling(mean_pool) \
|
||||||
>> (ExtractWindow(nW=1) >> LN(Maxout(in_width, in_width * 3))) \
|
>> Residual(zero_init(Maxout(in_width, in_width))) \
|
||||||
>> zero_init(Affine(hidden_width, in_width, drop_factor=0.0))
|
>> zero_init(Affine(hidden_width, in_width, drop_factor=0.0))
|
||||||
|
|
||||||
# TODO: ReLu instead of LN(Maxout) ?
|
# TODO: ReLu instead of LN(Maxout) ?
|
||||||
# TODO: more convolutions ?
|
# TODO: more convolutions ?
|
||||||
|
# sum_pool or mean_pool ?
|
||||||
|
|
||||||
return encoder
|
return encoder
|
||||||
|
|
||||||
|
@ -261,16 +268,17 @@ class EL_Model:
|
||||||
print(doc_encoding)
|
print(doc_encoding)
|
||||||
|
|
||||||
doc_encodings, bp_doc = self.article_encoder.begin_update(article_docs, drop=self.DROP)
|
doc_encodings, bp_doc = self.article_encoder.begin_update(article_docs, drop=self.DROP)
|
||||||
entity_encodings, bp_encoding = self.entity_encoder.begin_update(entities, drop=self.DROP)
|
|
||||||
concat_encodings = [list(entity_encodings[i]) + list(doc_encodings[i]) for i in range(len(entities))]
|
|
||||||
|
|
||||||
print("doc_encodings", len(doc_encodings), doc_encodings)
|
print("doc_encodings", len(doc_encodings), doc_encodings)
|
||||||
|
|
||||||
|
entity_encodings, bp_encoding = self.entity_encoder.begin_update(entities, drop=self.DROP)
|
||||||
print("entity_encodings", len(entity_encodings), entity_encodings)
|
print("entity_encodings", len(entity_encodings), entity_encodings)
|
||||||
print("concat_encodings", len(concat_encodings), concat_encodings)
|
|
||||||
|
concat_encodings = [list(entity_encodings[i]) + list(doc_encodings[i]) for i in range(len(entities))]
|
||||||
|
# print("concat_encodings", len(concat_encodings), concat_encodings)
|
||||||
|
|
||||||
predictions, bp_model = self.model.begin_update(np.asarray(concat_encodings), drop=self.DROP)
|
predictions, bp_model = self.model.begin_update(np.asarray(concat_encodings), drop=self.DROP)
|
||||||
print("predictions", predictions)
|
|
||||||
predictions = self.model.ops.flatten(predictions)
|
predictions = self.model.ops.flatten(predictions)
|
||||||
|
print("predictions", predictions)
|
||||||
golds = self.model.ops.asarray(golds)
|
golds = self.model.ops.asarray(golds)
|
||||||
|
|
||||||
loss, d_scores = self.get_loss(predictions, golds)
|
loss, d_scores = self.get_loss(predictions, golds)
|
||||||
|
@ -287,15 +295,15 @@ class EL_Model:
|
||||||
|
|
||||||
d_scores = d_scores.reshape((-1, 1))
|
d_scores = d_scores.reshape((-1, 1))
|
||||||
d_scores = d_scores.astype(np.float32)
|
d_scores = d_scores.astype(np.float32)
|
||||||
print("d_scores", d_scores)
|
# print("d_scores", d_scores)
|
||||||
|
|
||||||
model_gradient = bp_model(d_scores, sgd=self.sgd)
|
model_gradient = bp_model(d_scores, sgd=self.sgd)
|
||||||
print("model_gradient", model_gradient)
|
# print("model_gradient", model_gradient)
|
||||||
|
|
||||||
doc_gradient = [x[0:self.ARTICLE_WIDTH] for x in model_gradient]
|
doc_gradient = [x[0:self.ARTICLE_WIDTH] for x in model_gradient]
|
||||||
print("doc_gradient", doc_gradient)
|
# print("doc_gradient", doc_gradient)
|
||||||
entity_gradient = [x[self.ARTICLE_WIDTH:] for x in model_gradient]
|
entity_gradient = [x[self.ARTICLE_WIDTH:] for x in model_gradient]
|
||||||
print("entity_gradient", entity_gradient)
|
# print("entity_gradient", entity_gradient)
|
||||||
|
|
||||||
bp_doc(doc_gradient)
|
bp_doc(doc_gradient)
|
||||||
bp_encoding(entity_gradient)
|
bp_encoding(entity_gradient)
|
||||||
|
|
Loading…
Reference in New Issue