write entity linking pipe to file and keep vocab consistent between kb and nlp

svlandeg 2019-06-13 16:25:39 +02:00
parent b12001f368
commit 78dd3e11da
5 changed files with 226 additions and 93 deletions
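The sketch below summarises the round trip this commit works toward, using only calls that appear in the diff (create_pipe, set_kb, add_pipe, KnowledgeBase.dump/load_bulk, nlp.to_disk, spacy.load); the paths, the base model and the vector length of 64 are placeholders, not values taken from the commit.

import spacy
from spacy.kb import KnowledgeBase

nlp = spacy.load('en_core_web_lg')

# The KB must be built on the same Vocab as the pipeline it will serve.
kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=64)
# ... add entities and aliases ...
kb.dump('/tmp/kb')
nlp.to_disk('/tmp/nlp')

# Later: load the nlp first, then rebuild a KB that shares its vocab.
nlp_2 = spacy.load('/tmp/nlp')
kb_2 = KnowledgeBase(vocab=nlp_2.vocab, entity_vector_length=64)
kb_2.load_bulk('/tmp/kb')

el_pipe = nlp_2.create_pipe(name='entity_linker', config={'doc_cutoff': 300})
el_pipe.set_kb(kb_2)
nlp_2.add_pipe(el_pipe, last=True)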

View File

@@ -40,8 +40,8 @@ def create_kb(nlp, max_entities_per_alias, min_occ,
     title_list = list(title_to_id.keys())
     # TODO: remove this filter (just for quicker testing of code)
-    # title_list = title_list[0:34200]
-    # title_to_id = {t: title_to_id[t] for t in title_list}
+    title_list = title_list[0:342]
+    title_to_id = {t: title_to_id[t] for t in title_list}
     entity_list = [title_to_id[x] for x in title_list]

View File

@@ -6,6 +6,7 @@ import random
 from spacy.util import minibatch, compounding
 from examples.pipeline.wiki_entity_linking import wikipedia_processor as wp, kb_creator, training_set_creator, run_el
+from examples.pipeline.wiki_entity_linking.kb_creator import DESC_WIDTH

 import spacy
 from spacy.vocab import Vocab
@@ -22,41 +23,48 @@ ENTITY_DEFS = 'C:/Users/Sofie/Documents/data/wikipedia/entity_defs.csv'
 ENTITY_DESCR = 'C:/Users/Sofie/Documents/data/wikipedia/entity_descriptions.csv'
 KB_FILE = 'C:/Users/Sofie/Documents/data/wikipedia/kb'
-VOCAB_DIR = 'C:/Users/Sofie/Documents/data/wikipedia/vocab'
+NLP_1_DIR = 'C:/Users/Sofie/Documents/data/wikipedia/nlp_1'
+NLP_2_DIR = 'C:/Users/Sofie/Documents/data/wikipedia/nlp_2'
 TRAINING_DIR = 'C:/Users/Sofie/Documents/data/wikipedia/training_data_nel/'

 MAX_CANDIDATES = 10
 MIN_PAIR_OCC = 5
 DOC_CHAR_CUTOFF = 300
-EPOCHS = 10
+EPOCHS = 2
 DROPOUT = 0.1


 def run_pipeline():
     print("START", datetime.datetime.now())
     print()
-    nlp = spacy.load('en_core_web_lg')
-    my_kb = None
+    nlp_1 = spacy.load('en_core_web_lg')
+    nlp_2 = None
+    kb_1 = None
+    kb_2 = None

     # one-time methods to create KB and write to file
     to_create_prior_probs = False
     to_create_entity_counts = False
-    to_create_kb = False
+    to_create_kb = True

     # read KB back in from file
     to_read_kb = True
-    to_test_kb = False
+    to_test_kb = True

     # create training dataset
     create_wp_training = False

     # train the EL pipe
     train_pipe = True
+    measure_performance = False

     # test the EL pipe on a simple example
     to_test_pipeline = True

+    # write the NLP object, read back in and test again
+    test_nlp_io = True

     # STEP 1 : create prior probabilities from WP
     # run only once !
     if to_create_prior_probs:
@@ -75,7 +83,7 @@ def run_pipeline():
     # run only once !
     if to_create_kb:
         print("STEP 3a: to_create_kb", datetime.datetime.now())
-        my_kb = kb_creator.create_kb(nlp,
+        kb_1 = kb_creator.create_kb(nlp_1,
                                      max_entities_per_alias=MAX_CANDIDATES,
                                      min_occ=MIN_PAIR_OCC,
                                      entity_def_output=ENTITY_DEFS,
@@ -83,63 +91,66 @@ def run_pipeline():
                                      count_input=ENTITY_COUNTS,
                                      prior_prob_input=PRIOR_PROB,
                                      to_print=False)
-        print("kb entities:", my_kb.get_size_entities())
-        print("kb aliases:", my_kb.get_size_aliases())
+        print("kb entities:", kb_1.get_size_entities())
+        print("kb aliases:", kb_1.get_size_aliases())
         print()

-        print("STEP 3b: write KB", datetime.datetime.now())
-        my_kb.dump(KB_FILE)
-        nlp.vocab.to_disk(VOCAB_DIR)
+        print("STEP 3b: write KB and NLP", datetime.datetime.now())
+        kb_1.dump(KB_FILE)
+        nlp_1.to_disk(NLP_1_DIR)
         print()

     # STEP 4 : read KB back in from file
     if to_read_kb:
         print("STEP 4: to_read_kb", datetime.datetime.now())
-        my_vocab = Vocab()
-        my_vocab.from_disk(VOCAB_DIR)
-        my_kb = KnowledgeBase(vocab=my_vocab, entity_vector_length=64)  # TODO entity vectors
-        my_kb.load_bulk(KB_FILE)
-        print("kb entities:", my_kb.get_size_entities())
-        print("kb aliases:", my_kb.get_size_aliases())
+        # my_vocab = Vocab()
+        # my_vocab.from_disk(VOCAB_DIR)
+        # my_kb = KnowledgeBase(vocab=my_vocab, entity_vector_length=64)
+        nlp_2 = spacy.load(NLP_1_DIR)
+        kb_2 = KnowledgeBase(vocab=nlp_2.vocab, entity_vector_length=DESC_WIDTH)
+        kb_2.load_bulk(KB_FILE)
+        print("kb entities:", kb_2.get_size_entities())
+        print("kb aliases:", kb_2.get_size_aliases())
         print()

         # test KB
         if to_test_kb:
-            run_el.run_kb_toy_example(kb=my_kb)
+            run_el.run_kb_toy_example(kb=kb_2)
             print()

     # STEP 5: create a training dataset from WP
     if create_wp_training:
         print("STEP 5: create training dataset", datetime.datetime.now())
-        training_set_creator.create_training(kb=my_kb, entity_def_input=ENTITY_DEFS, training_output=TRAINING_DIR)
+        training_set_creator.create_training(kb=kb_2, entity_def_input=ENTITY_DEFS, training_output=TRAINING_DIR)

     # STEP 6: create the entity linking pipe
     if train_pipe:
         print("STEP 6: training Entity Linking pipe", datetime.datetime.now())
-        train_limit = 5000
-        dev_limit = 1000
+        train_limit = 10
+        dev_limit = 5
         print("Training on", train_limit, "articles")
         print("Dev testing on", dev_limit, "articles")
         print()

-        train_data = training_set_creator.read_training(nlp=nlp,
+        train_data = training_set_creator.read_training(nlp=nlp_2,
                                                         training_dir=TRAINING_DIR,
                                                         dev=False,
                                                         limit=train_limit,
                                                         to_print=False)

-        dev_data = training_set_creator.read_training(nlp=nlp,
+        dev_data = training_set_creator.read_training(nlp=nlp_2,
                                                       training_dir=TRAINING_DIR,
                                                       dev=True,
                                                       limit=dev_limit,
                                                       to_print=False)

-        el_pipe = nlp.create_pipe(name='entity_linker', config={"kb": my_kb, "doc_cutoff": DOC_CHAR_CUTOFF})
-        nlp.add_pipe(el_pipe, last=True)
+        el_pipe = nlp_2.create_pipe(name='entity_linker', config={"doc_cutoff": DOC_CHAR_CUTOFF})
+        el_pipe.set_kb(kb_2)
+        nlp_2.add_pipe(el_pipe, last=True)

-        other_pipes = [pipe for pipe in nlp.pipe_names if pipe != "entity_linker"]
-        with nlp.disable_pipes(*other_pipes):  # only train Entity Linking
-            nlp.begin_training()
+        other_pipes = [pipe for pipe in nlp_2.pipe_names if pipe != "entity_linker"]
+        with nlp_2.disable_pipes(*other_pipes):  # only train Entity Linking
+            nlp_2.begin_training()

         for itn in range(EPOCHS):
             random.shuffle(train_data)
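The crucial change in STEP 4 is that the reloaded KB is now constructed on the vocab of the freshly loaded nlp_2 instead of a separately serialized Vocab, so entity and alias hashes resolve against the same string store the pipeline uses. A minimal illustration of that invariant, reusing the script's own names (the assert is illustrative, not part of the commit):

nlp_2 = spacy.load(NLP_1_DIR)
kb_2 = KnowledgeBase(vocab=nlp_2.vocab, entity_vector_length=DESC_WIDTH)
kb_2.load_bulk(KB_FILE)

# KB and pipeline share one Vocab object, so their string tables cannot drift apart.
assert kb_2.vocab is nlp_2.vocab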
@@ -147,11 +158,11 @@ def run_pipeline():
             batches = minibatch(train_data, size=compounding(4.0, 128.0, 1.001))
             batchnr = 0

-            with nlp.disable_pipes(*other_pipes):
+            with nlp_2.disable_pipes(*other_pipes):
                 for batch in batches:
                     try:
                         docs, golds = zip(*batch)
-                        nlp.update(
+                        nlp_2.update(
                             docs,
                             golds,
                             drop=DROPOUT,
@@ -164,40 +175,62 @@ def run_pipeline():
             losses['entity_linker'] = losses['entity_linker'] / batchnr
             print("Epoch, train loss", itn, round(losses['entity_linker'], 2))

-        print()
-        print("STEP 7: performance measurement of Entity Linking pipe", datetime.datetime.now())
-        print()
+        if measure_performance:
+            print()
+            print("STEP 7: performance measurement of Entity Linking pipe", datetime.datetime.now())
+            print()

             # print(" measuring accuracy 1-1")
             el_pipe.context_weight = 1
             el_pipe.prior_weight = 1
             dev_acc_1_1 = _measure_accuracy(dev_data, el_pipe)
             train_acc_1_1 = _measure_accuracy(train_data, el_pipe)
             print("train/dev acc combo:", round(train_acc_1_1, 2), round(dev_acc_1_1, 2))

             # baseline using only prior probabilities
             el_pipe.context_weight = 0
             el_pipe.prior_weight = 1
             dev_acc_0_1 = _measure_accuracy(dev_data, el_pipe)
             train_acc_0_1 = _measure_accuracy(train_data, el_pipe)
             print("train/dev acc prior:", round(train_acc_0_1, 2), round(dev_acc_0_1, 2))

             # using only context
             el_pipe.context_weight = 1
             el_pipe.prior_weight = 0
             dev_acc_1_0 = _measure_accuracy(dev_data, el_pipe)
             train_acc_1_0 = _measure_accuracy(train_data, el_pipe)
             print("train/dev acc context:", round(train_acc_1_0, 2), round(dev_acc_1_0, 2))
             print()

     if to_test_pipeline:
         print()
         print("STEP 8: applying Entity Linking to toy example", datetime.datetime.now())
         print()
-        run_el_toy_example(kb=my_kb, nlp=nlp)
+        run_el_toy_example(nlp=nlp_2)
         print()

+    if test_nlp_io:
+        print()
+        print("STEP 9: testing NLP IO", datetime.datetime.now())
+        print()
+        print("writing to", NLP_2_DIR)
+        print(" vocab len nlp_2", len(nlp_2.vocab))
+        print(" vocab len kb_2", len(kb_2.vocab))
+        nlp_2.to_disk(NLP_2_DIR)
+        print()
+
+        print("reading from", NLP_2_DIR)
+        nlp_3 = spacy.load(NLP_2_DIR)
+        print(" vocab len nlp_3", len(nlp_3.vocab))
+        for pipe_name, pipe in nlp_3.pipeline:
+            if pipe_name == "entity_linker":
+                print(" vocab len kb_3", len(pipe.kb.vocab))
+        print()
+
+        print("running toy example with NLP 2")
+        run_el_toy_example(nlp=nlp_3)

     print()
     print("STOP", datetime.datetime.now())
@@ -239,7 +272,7 @@ def _measure_accuracy(data, el_pipe):
     return acc


-def run_el_toy_example(nlp, kb):
+def run_el_toy_example(nlp):
     text = "In The Hitchhiker's Guide to the Galaxy, written by Douglas Adams, " \
            "Douglas reminds us to always bring our towel. " \
            "The main character in Doug's novel is the man Arthur Dent, " \
@@ -261,4 +294,4 @@ def run_el_toy_example(nlp, kb):

 if __name__ == "__main__":
     run_pipeline()

View File

@@ -2,6 +2,8 @@
 # cython: profile=True
 # coding: utf8
 from collections import OrderedDict
+from pathlib import Path, WindowsPath

 from cpython.exc cimport PyErr_CheckSignals

 from spacy import util
@@ -389,6 +391,8 @@ cdef class Writer:
     def __init__(self, object loc):
         if path.exists(loc):
             assert not path.isdir(loc), "%s is directory." % loc
+        if isinstance(loc, Path):
+            loc = bytes(loc)
         cdef bytes bytes_loc = loc.encode('utf8') if type(loc) == unicode else loc
         self._fp = fopen(<char*>bytes_loc, 'wb')
         assert self._fp != NULL
@@ -431,6 +435,8 @@ cdef class Reader:
     def __init__(self, object loc):
         assert path.exists(loc)
         assert not path.isdir(loc)
+        if isinstance(loc, Path):
+            loc = bytes(loc)
         cdef bytes bytes_loc = loc.encode('utf8') if type(loc) == unicode else loc
         self._fp = fopen(<char*>bytes_loc, 'rb')
         if not self._fp:
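Writer and Reader ultimately hand the location to fopen as a char*, so a pathlib.Path now gets flattened to bytes before the existing str/bytes handling kicks in. A pure-Python sketch of the same normalisation (the helper name is made up for illustration):

from pathlib import Path

def _as_bytes_loc(loc):
    # Mirrors the guard added to Writer/Reader: Path -> bytes first,
    # then str -> utf8 bytes; bytes pass through untouched.
    if isinstance(loc, Path):
        loc = bytes(loc)
    return loc.encode('utf8') if isinstance(loc, str) else loc

print(_as_bytes_loc(Path('/tmp/kb')))  # b'/tmp/kb' on POSIX
print(_as_bytes_loc('/tmp/kb'))        # b'/tmp/kb'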

View File

@@ -11,6 +11,7 @@ from copy import copy, deepcopy
 from thinc.neural import Model
 import srsly

+from spacy.kb import KnowledgeBase
 from .tokenizer import Tokenizer
 from .vocab import Vocab
 from .lemmatizer import Lemmatizer
@@ -809,6 +810,14 @@ class Language(object):
         # Convert to list here in case exclude is (default) tuple
         exclude = list(exclude) + ["vocab"]
         util.from_disk(path, deserializers, exclude)
+
+        # download the KB for the entity linking component - requires the vocab
+        for pipe_name, pipe in self.pipeline:
+            if pipe_name == "entity_linker":
+                kb = KnowledgeBase(vocab=self.vocab, entity_vector_length=pipe.cfg["entity_width"])
+                kb.load_bulk(path / pipe_name / "kb")
+                pipe.set_kb(kb)
+
         self._path = path
         return self
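The effect is that a plain spacy.load() is enough to bring back a working entity linker: the vocab is deserialized first, a KnowledgeBase is rebuilt on top of it with the entity_width stored in the pipe's cfg, loaded from <model_dir>/entity_linker/kb, and re-attached via set_kb(). A rough usage sketch (the directory name is a placeholder):

import spacy

nlp = spacy.load('/tmp/nlp_with_entity_linker')  # runs Language.from_disk under the hood
el_pipe = nlp.get_pipe('entity_linker')

# The KB was rebuilt on the pipeline's own vocab; no separate vocab dir is needed.
assert el_pipe.kb is not None
assert el_pipe.kb.vocab is nlp.vocab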

View File

@@ -14,6 +14,7 @@ from thinc.misc import LayerNorm
 from thinc.neural.util import to_categorical
 from thinc.neural.util import get_array_module

+from spacy.kb import KnowledgeBase
 from ..tokens.doc cimport Doc
 from ..syntax.nn_parser cimport Parser
 from ..syntax.ner cimport BiluoPushDown
@@ -1077,7 +1078,7 @@ class EntityLinker(Pipe):
         hidden_width = cfg.get("hidden_width", 32)
         article_width = cfg.get("article_width", 128)
         sent_width = cfg.get("sent_width", 64)
-        entity_width = cfg["kb"].entity_vector_length
+        entity_width = cfg.get("entity_width")  # no default because this needs to correspond with the KB

         article_encoder = build_nel_encoder(in_width=embed_width, hidden_width=hidden_width, end_width=article_width, **cfg)
         sent_encoder = build_nel_encoder(in_width=embed_width, hidden_width=hidden_width, end_width=sent_width, **cfg)
@@ -1089,34 +1090,41 @@ class EntityLinker(Pipe):
         return article_encoder, sent_encoder, mention_encoder

     def __init__(self, **cfg):
+        self.article_encoder = True
+        self.sent_encoder = True
         self.mention_encoder = True
+        self.kb = None
         self.cfg = dict(cfg)
-        self.kb = self.cfg["kb"]
-        self.doc_cutoff = self.cfg["doc_cutoff"]
+        self.doc_cutoff = self.cfg.get("doc_cutoff", 150)
+
+    def use_avg_params(self):
+        # Modify the pipe's encoders/models, to use their average parameter values.
+        # TODO: this doesn't work yet because there's no exit method
+        self.article_encoder.use_params(self.sgd_article.averages)
+        self.sent_encoder.use_params(self.sgd_sent.averages)
+        self.mention_encoder.use_params(self.sgd_mention.averages)
+
+    def set_kb(self, kb):
+        self.kb = kb

     def require_model(self):
         # Raise an error if the component's model is not initialized.
         if getattr(self, "mention_encoder", None) in (None, True, False):
             raise ValueError(Errors.E109.format(name=self.name))

+    def require_kb(self):
+        # Raise an error if the knowledge base is not initialized.
+        if getattr(self, "kb", None) in (None, True, False):
+            # TODO: custom error
+            raise ValueError(Errors.E109.format(name=self.name))
+
     def begin_training(self, get_gold_tuples=lambda: [], pipeline=None, sgd=None, **kwargs):
+        self.require_kb()
+        self.cfg["entity_width"] = self.kb.entity_vector_length
+
         if self.mention_encoder is True:
             self.article_encoder, self.sent_encoder, self.mention_encoder = self.Model(**self.cfg)
             self.sgd_article = create_default_optimizer(self.article_encoder.ops)
             self.sgd_sent = create_default_optimizer(self.sent_encoder.ops)
             self.sgd_mention = create_default_optimizer(self.mention_encoder.ops)
         return self.sgd_article

     def update(self, docs, golds, state=None, drop=0.0, sgd=None, losses=None):
         self.require_model()
+        self.require_kb()

         if len(docs) != len(golds):
             raise ValueError(Errors.E077.format(value="EL training", n_docs=len(docs),
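With the KB no longer passed through cfg, the setup order matters: set_kb() has to run before begin_training(), which records the KB's entity_vector_length as cfg["entity_width"] so the encoders can be rebuilt with a matching output size when the pipe is later loaded from disk. A minimal sketch of that contract (nlp and kb are assumed to exist and to share a vocab):

el_pipe = nlp.create_pipe(name='entity_linker', config={'doc_cutoff': 300})
el_pipe.set_kb(kb)                # must happen before training starts
nlp.add_pipe(el_pipe, last=True)

optimizer = nlp.begin_training()  # triggers require_kb() and stores cfg["entity_width"]
assert el_pipe.cfg['entity_width'] == kb.entity_vector_length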
@@ -1220,6 +1228,7 @@ class EntityLinker(Pipe):
     def predict(self, docs):
         self.require_model()
+        self.require_kb()

         if isinstance(docs, Doc):
             docs = [docs]
@@ -1228,30 +1237,32 @@ class EntityLinker(Pipe):
         final_kb_ids = list()

         for i, article_doc in enumerate(docs):
-            doc_encoding = self.article_encoder([article_doc])
-            for ent in article_doc.ents:
-                sent_doc = ent.sent.as_doc()
-                sent_encoding = self.sent_encoder([sent_doc])
-                concat_encoding = [list(doc_encoding[0]) + list(sent_encoding[0])]
-                mention_encoding = self.mention_encoder(np.asarray([concat_encoding[0]]))
-                mention_enc_t = np.transpose(mention_encoding)
+            if len(article_doc) > 0:
+                doc_encoding = self.article_encoder([article_doc])
+                for ent in article_doc.ents:
+                    sent_doc = ent.sent.as_doc()
+                    if len(sent_doc) > 0:
+                        sent_encoding = self.sent_encoder([sent_doc])
+                        concat_encoding = [list(doc_encoding[0]) + list(sent_encoding[0])]
+                        mention_encoding = self.mention_encoder(np.asarray([concat_encoding[0]]))
+                        mention_enc_t = np.transpose(mention_encoding)

                         candidates = self.kb.get_candidates(ent.text)
                         if candidates:
                             scores = list()
                             for c in candidates:
                                 prior_prob = c.prior_prob * self.prior_weight
                                 kb_id = c.entity_
                                 entity_encoding = c.entity_vector
                                 sim = cosine(np.asarray([entity_encoding]), mention_enc_t) * self.context_weight
                                 score = prior_prob + sim - (prior_prob*sim)  # put weights on the different factors ?
                                 scores.append(score)

                             # TODO: thresholding
                             best_index = scores.index(max(scores))
                             best_candidate = candidates[best_index]
                             final_entities.append(ent)
                             final_kb_ids.append(best_candidate.entity_)

         return final_entities, final_kb_ids
@@ -1260,6 +1271,80 @@ class EntityLinker(Pipe):
             for token in entity:
                 token.ent_kb_id_ = kb_id

+    def to_bytes(self, exclude=tuple(), **kwargs):
+        """Serialize the pipe to a bytestring.
+
+        exclude (list): String names of serialization fields to exclude.
+        RETURNS (bytes): The serialized object.
+        """
+        serialize = OrderedDict()
+        serialize["cfg"] = lambda: srsly.json_dumps(self.cfg)
+        serialize["kb"] = self.kb.to_bytes  # TODO
+        if self.mention_encoder not in (True, False, None):
+            serialize["article_encoder"] = self.article_encoder.to_bytes
+            serialize["sent_encoder"] = self.sent_encoder.to_bytes
+            serialize["mention_encoder"] = self.mention_encoder.to_bytes
+        exclude = util.get_serialization_exclude(serialize, exclude, kwargs)
+        return util.to_bytes(serialize, exclude)
+
+    def from_bytes(self, bytes_data, exclude=tuple(), **kwargs):
+        """Load the pipe from a bytestring."""
+        deserialize = OrderedDict()
+        deserialize["cfg"] = lambda b: self.cfg.update(srsly.json_loads(b))
+        deserialize["kb"] = lambda b: self.kb.from_bytes(b)  # TODO
+        deserialize["article_encoder"] = lambda b: self.article_encoder.from_bytes(b)
+        deserialize["sent_encoder"] = lambda b: self.sent_encoder.from_bytes(b)
+        deserialize["mention_encoder"] = lambda b: self.mention_encoder.from_bytes(b)
+        exclude = util.get_serialization_exclude(deserialize, exclude, kwargs)
+        util.from_bytes(bytes_data, deserialize, exclude)
+        return self
+
+    def to_disk(self, path, exclude=tuple(), **kwargs):
+        """Serialize the pipe to disk."""
+        serialize = OrderedDict()
+        serialize["cfg"] = lambda p: srsly.write_json(p, self.cfg)
+        serialize["kb"] = lambda p: self.kb.dump(p)
+        if self.mention_encoder not in (None, True, False):
+            serialize["article_encoder"] = lambda p: p.open("wb").write(self.article_encoder.to_bytes())
+            serialize["sent_encoder"] = lambda p: p.open("wb").write(self.sent_encoder.to_bytes())
+            serialize["mention_encoder"] = lambda p: p.open("wb").write(self.mention_encoder.to_bytes())
+        exclude = util.get_serialization_exclude(serialize, exclude, kwargs)
+        util.to_disk(path, serialize, exclude)
+
+    def from_disk(self, path, exclude=tuple(), **kwargs):
+        """Load the pipe from disk."""
+        def load_article_encoder(p):
+            if self.article_encoder is True:
+                self.article_encoder, _, _ = self.Model(**self.cfg)
+            self.article_encoder.from_bytes(p.open("rb").read())
+
+        def load_sent_encoder(p):
+            if self.sent_encoder is True:
+                _, self.sent_encoder, _ = self.Model(**self.cfg)
+            self.sent_encoder.from_bytes(p.open("rb").read())
+
+        def load_mention_encoder(p):
+            if self.mention_encoder is True:
+                _, _, self.mention_encoder = self.Model(**self.cfg)
+            self.mention_encoder.from_bytes(p.open("rb").read())
+
+        deserialize = OrderedDict()
+        deserialize["cfg"] = lambda p: self.cfg.update(_load_cfg(p))
+        deserialize["article_encoder"] = load_article_encoder
+        deserialize["sent_encoder"] = load_sent_encoder
+        deserialize["mention_encoder"] = load_mention_encoder
+        exclude = util.get_serialization_exclude(deserialize, exclude, kwargs)
+        util.from_disk(path, deserialize, exclude)
+        return self
+
+    def rehearse(self, docs, sgd=None, losses=None, **config):
+        # TODO
+        pass
+
+    def add_label(self, label):
+        pass
+

 class Sentencizer(object):
     """Segment the Doc into sentences using a rule-based strategy.