mirror of https://github.com/explosion/spaCy.git

Fix keras deep learning example

parent 19501f3340
commit 80aa4e114b
@@ -12,17 +12,23 @@ from spacy_hook import create_similarity_pipeline
 from keras_decomposable_attention import build_model
 
+try:
+    import cPickle as pickle
+except ImportError:
+    import pickle
+
 
 def train(model_dir, train_loc, dev_loc, shape, settings):
     train_texts1, train_texts2, train_labels = read_snli(train_loc)
     dev_texts1, dev_texts2, dev_labels = read_snli(dev_loc)
 
     print("Loading spaCy")
     nlp = spacy.load('en')
+    assert nlp.path is not None
     print("Compiling network")
     model = build_model(get_embeddings(nlp.vocab), shape, settings)
     print("Processing texts...")
     Xs = []
     for texts in (train_texts1, train_texts2, dev_texts1, dev_texts2):
         Xs.append(get_word_ids(list(nlp.pipe(texts, n_threads=20, batch_size=20000)),
                                max_length=shape[0],
@@ -36,35 +42,41 @@ def train(model_dir, train_loc, dev_loc, shape, settings):
               validation_data=([dev_X1, dev_X2], dev_labels),
               nb_epoch=settings['nr_epoch'],
               batch_size=settings['batch_size'])
+    if not (nlp.path / 'similarity').exists():
+        (nlp.path / 'similarity').mkdir()
+    print("Saving to", model_dir / 'similarity')
+    weights = model.get_weights()
+    with (nlp.path / 'similarity' / 'model').open('wb') as file_:
+        pickle.dump(weights[1:], file_)
+    with (nlp.path / 'similarity' / 'config.json').open('wb') as file_:
+        file_.write(model.to_json())
 
 
 def evaluate(model_dir, dev_loc):
-    nlp = spacy.load('en', path=model_dir,
-        tagger=False, parser=False, entity=False, matcher=False,
+    dev_texts1, dev_texts2, dev_labels = read_snli(dev_loc)
+    nlp = spacy.load('en',
         create_pipeline=create_similarity_pipeline)
-    n = 0
-    correct = 0
-    for (text1, text2), label in zip(dev_texts, dev_labels):
+    total = 0.
+    correct = 0.
+    for text1, text2, label in zip(dev_texts1, dev_texts2, dev_labels):
         doc1 = nlp(text1)
         doc2 = nlp(text2)
         sim = doc1.similarity(doc2)
-        if bool(sim >= 0.5) == label:
+        if sim.argmax() == label.argmax():
             correct += 1
-        n += 1
+        total += 1
     return correct, total
 
 
 def demo(model_dir):
     nlp = spacy.load('en', path=model_dir,
-        tagger=False, parser=False, entity=False, matcher=False,
         create_pipeline=create_similarity_pipeline)
-    doc1 = nlp(u'Worst fries ever! Greasy and horrible...')
-    doc2 = nlp(u'The milkshakes are good. The fries are bad.')
-    print('doc1.similarity(doc2)', doc1.similarity(doc2))
-    sent1a, sent1b = doc1.sents
-    print('sent1a.similarity(sent1b)', sent1a.similarity(sent1b))
-    print('sent1a.similarity(doc2)', sent1a.similarity(doc2))
-    print('sent1b.similarity(doc2)', sent1b.similarity(doc2))
+    doc1 = nlp(u'What were the best crime fiction books in 2016?')
+    doc2 = nlp(
+        u'What should I read that was published last year? I like crime stories.')
+    print(doc1)
+    print(doc2)
+    print("Similarity", doc1.similarity(doc2))
 
 
 LABELS = {'entailment': 0, 'contradiction': 1, 'neutral': 2}
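Review note: the evaluation loop changes because the similarity hook now returns the model's 3-class distribution over the SNLI labels (see LABELS above) rather than a scalar, so a prediction is correct when its argmax matches the one-hot gold label. A minimal sketch with made-up numbers:

import numpy

sim = numpy.array([0.1, 0.7, 0.2])     # hypothetical model output over (entailment, contradiction, neutral)
label = numpy.array([0., 1., 0.])      # one-hot gold label: contradiction
print(sim.argmax() == label.argmax())  # True -> the pair counts as correct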
@@ -119,7 +131,8 @@ def main(mode, model_dir, train_loc, dev_loc,
     if mode == 'train':
         train(model_dir, train_loc, dev_loc, shape, settings)
     elif mode == 'evaluate':
-        evaluate(model_dir, dev_loc)
+        correct, total = evaluate(model_dir, dev_loc)
+        print(correct, '/', total, correct / total)
     else:
         demo(model_dir)
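Review note: evaluate() now seeds its counters as floats (correct = 0., total = 0.), which keeps the accuracy printed here a true division even on Python 2, where int / int truncates:

correct, total = 3., 4.
print(correct, '/', total, correct / total)  # 3.0 / 4.0 0.75 (would print 0 with int counters on Python 2)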
@@ -12,6 +12,8 @@ from keras.models import Sequential, Model, model_from_json
 from keras.regularizers import l2
 from keras.optimizers import Adam
 from keras.layers.normalization import BatchNormalization
+from keras.layers.pooling import GlobalAveragePooling1D, GlobalMaxPooling1D
+from keras.layers import Merge
 
 
 def build_model(vectors, shape, settings):
@@ -29,11 +31,11 @@ def build_model(vectors, shape, settings):
     align = _SoftAlignment(max_length, nr_hidden)
     compare = _Comparison(max_length, nr_hidden, dropout=settings['dropout'])
     entail = _Entailment(nr_hidden, nr_class, dropout=settings['dropout'])
 
     # Declare the model as a computational graph.
     sent1 = embed(ids1)  # Shape: (i, n)
     sent2 = embed(ids2)  # Shape: (j, n)
 
     if settings['gru_encode']:
         sent1 = encode(sent1)
         sent2 = encode(sent2)
@@ -42,12 +44,12 @@ def build_model(vectors, shape, settings):
 
     align1 = align(sent2, attention)
     align2 = align(sent1, attention, transpose=True)
 
     feats1 = compare(sent1, align1)
     feats2 = compare(sent2, align2)
 
     scores = entail(feats1, feats2)
 
     # Now that we have the input/output, we can construct the Model object...
     model = Model(input=[ids1, ids2], output=[scores])
 
@@ -93,7 +95,7 @@ class _StaticEmbedding(object):
         def get_output_shape(shapes):
             print(shapes)
             return shapes[0]
         mod_sent = self.mod_ids(sentence)
         tuning = self.tune(mod_sent)
         #tuning = merge([tuning, mod_sent],
         #    mode=lambda AB: AB[0] * (K.clip(K.cast(AB[1], 'float32'), 0, 1)),
@@ -129,7 +131,7 @@ class _Attention(object):
         self.model.add(Dense(nr_hidden, name='attend2',
             init='he_normal', W_regularizer=l2(L2), activation='relu'))
         self.model = TimeDistributed(self.model)
 
     def __call__(self, sent1, sent2):
         def _outer(AB):
             att_ji = K.batch_dot(AB[1], K.permute_dimensions(AB[0], (0, 2, 1)))
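For readers following _outer: K.batch_dot(AB[1], K.permute_dimensions(AB[0], (0, 2, 1))) is a batched matrix product that scores every token pair across the two sentences. A numpy sketch with illustrative shapes (not from the patch):

import numpy

a = numpy.random.rand(2, 5, 4)  # (batch, len1, nr_hidden): projected sentence 1
b = numpy.random.rand(2, 6, 4)  # (batch, len2, nr_hidden): projected sentence 2
att_ji = numpy.einsum('bjk,bik->bji', b, a)  # same result as the batch_dot above
print(att_ji.shape)  # (2, 6, 5): one raw alignment score per token pair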
@@ -158,7 +160,7 @@ class _SoftAlignment(object):
             return K.batch_dot(sm_att, mat)
         return merge([attention, sentence], mode=_normalize_attention,
             output_shape=(self.max_length, self.nr_hidden))  # Shape: (i, n)
 
 
 class _Comparison(object):
     def __init__(self, words, nr_hidden, L2=0.0, dropout=0.0):
@@ -176,10 +178,12 @@ class _Comparison(object):
 
     def __call__(self, sent, align, **kwargs):
         result = self.model(merge([sent, align], mode='concat'))  # Shape: (i, n)
-        result = _GlobalSumPooling1D()(result, mask=self.words)
-        result = BatchNormalization()(result)
+        avged = GlobalAveragePooling1D()(result, mask=self.words)
+        maxed = GlobalMaxPooling1D()(result, mask=self.words)
+        merged = merge([avged, maxed])
+        result = BatchNormalization()(merged)
         return result
 
 
 class _Entailment(object):
     def __init__(self, nr_hidden, nr_out, dropout=0.0, L2=0.0):
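The _Comparison change swaps the custom sum pooling for average plus max pooling over the timestep axis, combined by merge(). A plain numpy sketch of the two reductions, ignoring the mask for simplicity (in Keras 1.x, merge() without an explicit mode defaults to element-wise sum):

import numpy

feats = numpy.random.rand(2, 10, 8)  # hypothetical (batch, timesteps, features)
avged = feats.mean(axis=1)           # what GlobalAveragePooling1D computes
maxed = feats.max(axis=1)            # what GlobalMaxPooling1D computes
merged = avged + maxed               # merge([avged, maxed]) with the default 'sum' mode
print(merged.shape)                  # (2, 8)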
@@ -251,7 +255,7 @@ def test_fit_model():
     shape = (10, 16, 3)
     settings = {'lr': 0.001, 'dropout': 0.2, 'gru_encode':True}
     model = build_model(vectors, shape, settings)
 
     train_X = _generate_X(20, shape[0], vectors.shape[1])
     train_Y = _generate_Y(20, shape[2])
     dev_X = _generate_X(15, shape[0], vectors.shape[1])
@@ -261,6 +265,4 @@ def test_fit_model():
               batch_size=4)
 
 
-
-
 __all__ = [build_model]
@@ -1,33 +1,40 @@
 from keras.models import model_from_json
 import numpy
 import numpy.random
+import json
+from spacy.tokens.span import Span
+
+try:
+    import cPickle as pickle
+except ImportError:
+    import pickle
 
 
 class KerasSimilarityShim(object):
     @classmethod
-    def load(cls, path, nlp, get_features=None):
+    def load(cls, path, nlp, get_features=None, max_length=100):
         if get_features is None:
-            get_features = doc2ids
+            get_features = get_word_ids
         with (path / 'config.json').open() as file_:
-            config = json.load(file_)
-            model = model_from_json(config['model'])
+            model = model_from_json(file_.read())
         with (path / 'model').open('rb') as file_:
             weights = pickle.load(file_)
         embeddings = get_embeddings(nlp.vocab)
         model.set_weights([embeddings] + weights)
-        return cls(model, get_features=get_features)
+        return cls(model, get_features=get_features, max_length=max_length)
 
-    def __init__(self, model, get_features=None):
+    def __init__(self, model, get_features=None, max_length=100):
         self.model = model
         self.get_features = get_features
+        self.max_length = max_length
 
     def __call__(self, doc):
         doc.user_hooks['similarity'] = self.predict
         doc.user_span_hooks['similarity'] = self.predict
 
     def predict(self, doc1, doc2):
-        x1 = self.get_features(doc1)
-        x2 = self.get_features(doc2)
+        x1 = self.get_features([doc1], max_length=self.max_length, tree_truncate=True)
+        x2 = self.get_features([doc2], max_length=self.max_length, tree_truncate=True)
         scores = self.model.predict([x1, x2])
         return scores[0]
 
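A hedged usage sketch of the reworked shim, assuming train() above has already saved a model under nlp.path / 'similarity':

import spacy
from spacy_hook import KerasSimilarityShim

nlp = spacy.load('en')
shim = KerasSimilarityShim.load(nlp.path / 'similarity', nlp, max_length=10)

doc1 = nlp(u'What were the best crime fiction books in 2016?')
doc2 = nlp(u'What should I read that was published last year?')
shim(doc1)  # installs the similarity hook on the doc and its spans
shim(doc2)
print(doc1.similarity(doc2))  # now routed through the Keras model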
@@ -45,7 +52,10 @@ def get_word_ids(docs, rnn_encode=False, tree_truncate=False, max_length=100, nr
     Xs = numpy.zeros((len(docs), max_length), dtype='int32')
     for i, doc in enumerate(docs):
         if tree_truncate:
-            queue = [sent.root for sent in doc.sents]
+            if isinstance(doc, Span):
+                queue = [doc.root]
+            else:
+                queue = [sent.root for sent in doc.sents]
         else:
             queue = list(doc)
         words = []
@@ -71,7 +81,9 @@ def get_word_ids(docs, rnn_encode=False, tree_truncate=False, max_length=100, nr
 
 
 def create_similarity_pipeline(nlp):
-    return [SimilarityModel.load(
-        nlp.path / 'similarity',
-        nlp,
-        feature_extracter=get_features)]
+    return [
+        nlp.tagger,
+        nlp.entity,
+        nlp.parser,
+        KerasSimilarityShim.load(nlp.path / 'similarity', nlp, max_length=10)
+    ]
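And how the factory is consumed, mirroring the evaluate() change in the first file: spaCy 1.x calls create_pipeline(nlp) at load time and runs each returned component over every document, so the tagger, entity recognizer and parser still run before the Keras shim:

import spacy
from spacy_hook import create_similarity_pipeline

nlp = spacy.load('en', create_pipeline=create_similarity_pipeline)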