From 8e7414daceed0a4f9d58dd12ad2f22be7f0097be Mon Sep 17 00:00:00 2001
From: Sofie Van Landeghem
Date: Sun, 27 Oct 2019 16:01:32 +0100
Subject: [PATCH] Match pop with append for training format (#4516)

* trying to fix script - not successful yet
* match pop() with extend() to avoid changing the data
* few more pop-extend fixes
* reinsert deleted print statement
* fix print statement
* add last tested version
* append instead of extend
* add in few comments
* quick fix for 4402 + unit test
* fixing number of docs (not counting cats)
* more fixes
* fix len
* print tmp file instead of using data from examples dir
* print tmp file instead of using data from examples dir (2)
---
 examples/training/ner_multitask_objective.py | 35 ++++---
 spacy/gold.pyx                               | 41 ++++++---
 spacy/language.py                            |  3 +-
 spacy/pipeline/pipes.pyx                     |  7 +-
 spacy/syntax/arc_eager.pyx                   |  3 +-
 spacy/syntax/ner.pyx                         |  3 +-
 spacy/syntax/nn_parser.pyx                   |  3 +-
 spacy/syntax/nonproj.pyx                     |  6 ++
 spacy/tests/regression/test_issue4402.py     | 97 ++++++++++++++++++++
 9 files changed, 164 insertions(+), 34 deletions(-)
 create mode 100644 spacy/tests/regression/test_issue4402.py

diff --git a/examples/training/ner_multitask_objective.py b/examples/training/ner_multitask_objective.py
index 5d44ed649..fc6eb42c1 100644
--- a/examples/training/ner_multitask_objective.py
+++ b/examples/training/ner_multitask_objective.py
@@ -18,7 +18,7 @@ during training. We discard the auxiliary model before run-time. The specific
 example here is not necessarily a good idea --- but it shows how an arbitrary
 objective function for some word can be used.
 
-Developed and tested for spaCy 2.0.6
+Developed for spaCy 2.0.6 and last tested for 2.2.2
 """
 import random
 import plac
@@ -26,6 +26,8 @@ import spacy
 import os.path
 from spacy.gold import read_json_file, GoldParse
 
+from spacy.tokens import Doc
+
 random.seed(0)
 
 PWD = os.path.dirname(__file__)
@@ -56,22 +58,29 @@ def main(n_iter=10):
         ner.add_multitask_objective(get_position_label)
     nlp.add_pipe(ner)
 
-    print("Create data", len(TRAIN_DATA))
+    _, sents = TRAIN_DATA[0]
+    print("Create data, # of sentences =", len(sents) - 1)  # not counting the cats attribute
     optimizer = nlp.begin_training(get_gold_tuples=lambda: TRAIN_DATA)
     for itn in range(n_iter):
         random.shuffle(TRAIN_DATA)
         losses = {}
-        for text, annot_brackets in TRAIN_DATA:
-            annotations, _ = annot_brackets
-            doc = nlp.make_doc(text)
-            gold = GoldParse.from_annot_tuples(doc, annotations[0])
-            nlp.update(
-                [doc],  # batch of texts
-                [gold],  # batch of annotations
-                drop=0.2,  # dropout - make it harder to memorise data
-                sgd=optimizer,  # callable to update weights
-                losses=losses,
-            )
+
+        for raw_text, annots_brackets in TRAIN_DATA:
+            cats = annots_brackets.pop()
+            for annotations, _ in annots_brackets:
+                annotations.append(cats)  # temporarily add it here for from_annot_tuples to work
+                doc = Doc(nlp.vocab, words=annotations[1])
+                gold = GoldParse.from_annot_tuples(doc, annotations)
+                annotations.pop()  # restore data
+
+                nlp.update(
+                    [doc],  # batch of texts
+                    [gold],  # batch of annotations
+                    drop=0.2,  # dropout - make it harder to memorise data
+                    sgd=optimizer,  # callable to update weights
+                    losses=losses,
+                )
+            annots_brackets.append(cats)  # restore data
         print(losses.get("nn_labeller", 0.0), losses["ner"])
 
     # test the trained model
diff --git a/spacy/gold.pyx b/spacy/gold.pyx
index 4d86d4e86..817b059ce 100644
--- a/spacy/gold.pyx
+++ b/spacy/gold.pyx
@@ -55,22 +55,22 @@ def tags_to_entities(tags):
 
 
 def merge_sents(sents):
-    m_deps = [[], [], [], [], [], []]
+    m_sents = [[], [], [], [], [], []]
     m_brackets = []
     m_cats = sents.pop()
     i = 0
     for (ids, words, tags, heads, labels, ner), brackets in sents:
-        m_deps[0].extend(id_ + i for id_ in ids)
-        m_deps[1].extend(words)
-        m_deps[2].extend(tags)
-        m_deps[3].extend(head + i for head in heads)
-        m_deps[4].extend(labels)
-        m_deps[5].extend(ner)
+        m_sents[0].extend(id_ + i for id_ in ids)
+        m_sents[1].extend(words)
+        m_sents[2].extend(tags)
+        m_sents[3].extend(head + i for head in heads)
+        m_sents[4].extend(labels)
+        m_sents[5].extend(ner)
         m_brackets.extend((b["first"] + i, b["last"] + i, b["label"])
                           for b in brackets)
         i += len(ids)
-    m_deps.append(m_cats)
-    return [(m_deps, m_brackets)]
+    sents.append(m_cats)  # restore original data
+    return [[(m_sents, m_brackets)], m_cats]
 
 
 _NORM_MAP = {"``": '"', "''": '"'}
@@ -248,6 +248,7 @@
             if self.limit and i >= self.limit:
                 break
             i += 1
+        paragraph_tuples.append(cats)  # restore original data
         return n
 
     def train_docs(self, nlp, gold_preproc=False, max_length=None,
@@ -288,26 +289,36 @@
     @classmethod
     def _make_docs(cls, nlp, raw_text, paragraph_tuples, gold_preproc,
                    noise_level=0.0, orth_variant_level=0.0):
+        cats = paragraph_tuples.pop()
         if raw_text is not None:
             raw_text, paragraph_tuples = make_orth_variants(nlp, raw_text, paragraph_tuples,
                                                             orth_variant_level=orth_variant_level)
             raw_text = add_noise(raw_text, noise_level)
-            return [nlp.make_doc(raw_text)], paragraph_tuples
+            result = [nlp.make_doc(raw_text)], paragraph_tuples
         else:
             docs = []
             raw_text, paragraph_tuples = make_orth_variants(nlp, None, paragraph_tuples,
                                                             orth_variant_level=orth_variant_level)
-            return [Doc(nlp.vocab, words=add_noise(sent_tuples[1], noise_level))
-                    for (sent_tuples, brackets) in paragraph_tuples], paragraph_tuples
+            result = [Doc(nlp.vocab, words=add_noise(sent_tuples[1], noise_level))
+                      for (sent_tuples, brackets) in paragraph_tuples], paragraph_tuples
+        paragraph_tuples.append(cats)
+        return result
 
     @classmethod
     def _make_golds(cls, docs, paragraph_tuples, make_projective):
+        cats = paragraph_tuples.pop()
         if len(docs) != len(paragraph_tuples):
             n_annots = len(paragraph_tuples)
             raise ValueError(Errors.E070.format(n_docs=len(docs), n_annots=n_annots))
-        return [GoldParse.from_annot_tuples(doc, sent_tuples,
-                                            make_projective=make_projective)
-                for doc, (sent_tuples, brackets)
-                in zip(docs, paragraph_tuples)]
+        result = []
+        for doc, brack_annot in zip(docs, paragraph_tuples):
+            if len(brack_annot) == 1:
+                brack_annot = brack_annot[0]
+            sent_tuples, brackets = brack_annot
+            sent_tuples.append(cats)
+            result.append(GoldParse.from_annot_tuples(doc, sent_tuples, make_projective=make_projective))
+            sent_tuples.pop()
+        paragraph_tuples.append(cats)
+        return result
 
 
 def make_orth_variants(nlp, raw, paragraph_tuples, orth_variant_level=0.0):
diff --git a/spacy/language.py b/spacy/language.py
index a7d1f3a70..c1aaf054c 100644
--- a/spacy/language.py
+++ b/spacy/language.py
@@ -598,10 +598,11 @@ class Language(object):
         # Populate vocab
         else:
             for _, annots_brackets in get_gold_tuples():
-                _ = annots_brackets.pop()
+                cats = annots_brackets.pop()
                 for annots, _ in annots_brackets:
                     for word in annots[1]:
                         _ = self.vocab[word]  # noqa: F841
+                annots_brackets.append(cats)  # restore original data
         if cfg.get("device", -1) >= 0:
             util.use_gpu(cfg["device"])
         if self.vocab.vectors.data.shape[1] >= 1:
diff --git a/spacy/pipeline/pipes.pyx b/spacy/pipeline/pipes.pyx
index e33c6259b..232241772 100644
--- a/spacy/pipeline/pipes.pyx
+++ b/spacy/pipeline/pipes.pyx
@@ -517,7 +517,7 @@ class Tagger(Pipe):
         orig_tag_map = dict(self.vocab.morphology.tag_map)
         new_tag_map = OrderedDict()
         for raw_text, annots_brackets in get_gold_tuples():
-            _ = annots_brackets.pop()
+            cats = annots_brackets.pop()
             for annots, brackets in annots_brackets:
                 ids, words, tags, heads, deps, ents = annots
                 for tag in tags:
@@ -525,6 +525,7 @@
                         new_tag_map[tag] = orig_tag_map[tag]
                     else:
                         new_tag_map[tag] = {POS: X}
+            annots_brackets.append(cats)  # restore original data
         cdef Vocab vocab = self.vocab
         if new_tag_map:
             vocab.morphology = Morphology(vocab.strings, new_tag_map,
@@ -703,12 +704,14 @@ class MultitaskObjective(Tagger):
                        sgd=None, **kwargs):
         gold_tuples = nonproj.preprocess_training_data(get_gold_tuples())
         for raw_text, annots_brackets in gold_tuples:
+            cats = annots_brackets.pop()
             for annots, brackets in annots_brackets:
                 ids, words, tags, heads, deps, ents = annots
                 for i in range(len(ids)):
                     label = self.make_label(i, words, tags, heads, deps, ents)
                     if label is not None and label not in self.labels:
                         self.labels[label] = len(self.labels)
+            annots_brackets.append(cats)
         if self.model is True:
             token_vector_width = util.env_opt("token_vector_width")
             self.model = self.Model(len(self.labels), tok2vec=tok2vec)
@@ -1035,7 +1038,7 @@ class TextCategorizer(Pipe):
     def begin_training(self, get_gold_tuples=lambda: [], pipeline=None, sgd=None,
                        **kwargs):
         for raw_text, annots_brackets in get_gold_tuples():
-            cats = annots_brackets.pop()
+            cats = annots_brackets[-1]
             for cat in cats:
                 self.add_label(cat)
         if self.model is True:
diff --git a/spacy/syntax/arc_eager.pyx b/spacy/syntax/arc_eager.pyx
index 5a7355061..7d0771478 100644
--- a/spacy/syntax/arc_eager.pyx
+++ b/spacy/syntax/arc_eager.pyx
@@ -342,7 +342,7 @@ cdef class ArcEager(TransitionSystem):
                 actions[RIGHT][label] = 1
                 actions[REDUCE][label] = 1
         for raw_text, sents in kwargs.get('gold_parses', []):
-            _ = sents.pop()
+            cats = sents.pop()
             for (ids, words, tags, heads, labels, iob), ctnts in sents:
                 heads, labels = nonproj.projectivize(heads, labels)
                 for child, head, label in zip(ids, heads, labels):
@@ -356,6 +356,7 @@ cdef class ArcEager(TransitionSystem):
                     elif head > child:
                         actions[LEFT][label] += 1
                         actions[SHIFT][''] += 1
+            sents.append(cats)  # restore original data
         if min_freq is not None:
             for action, label_freqs in actions.items():
                 for label, freq in list(label_freqs.items()):
diff --git a/spacy/syntax/ner.pyx b/spacy/syntax/ner.pyx
index 3bd096463..b059dfc05 100644
--- a/spacy/syntax/ner.pyx
+++ b/spacy/syntax/ner.pyx
@@ -73,13 +73,14 @@ cdef class BiluoPushDown(TransitionSystem):
                     actions[action][entity_type] = 1
         moves = ('M', 'B', 'I', 'L', 'U')
         for raw_text, sents in kwargs.get('gold_parses', []):
-            _ = sents.pop()
+            cats = sents.pop()
             for (ids, words, tags, heads, labels, biluo), _ in sents:
                 for i, ner_tag in enumerate(biluo):
                     if ner_tag != 'O' and ner_tag != '-':
                         _, label = ner_tag.split('-', 1)
                         for action in (BEGIN, IN, LAST, UNIT):
                             actions[action][label] += 1
+            sents.append(cats)  # restore original data
         return actions
 
     @property
diff --git a/spacy/syntax/nn_parser.pyx b/spacy/syntax/nn_parser.pyx
index 92168631c..6c8241aec 100644
--- a/spacy/syntax/nn_parser.pyx
+++ b/spacy/syntax/nn_parser.pyx
@@ -606,12 +606,13 @@ cdef class Parser:
             doc_sample = []
             gold_sample = []
             for raw_text, annots_brackets in islice(get_gold_tuples(), 1000):
-                _ = annots_brackets.pop()
+                cats = annots_brackets.pop()
                 for annots, brackets in annots_brackets:
                     ids, words, tags, heads, deps, ents = annots
                     doc_sample.append(Doc(self.vocab, words=words))
                     gold_sample.append(GoldParse(doc_sample[-1], words=words, tags=tags, heads=heads,
                                                  deps=deps, entities=ents))
+                annots_brackets.append(cats)  # restore original data
             self.model.begin_training(doc_sample, gold_sample)
             if pipeline is not None:
                 self.init_multitask_objectives(get_gold_tuples, pipeline, sgd=sgd, **cfg)
diff --git a/spacy/syntax/nonproj.pyx b/spacy/syntax/nonproj.pyx
index 53e8a9cfe..1665c7929 100644
--- a/spacy/syntax/nonproj.pyx
+++ b/spacy/syntax/nonproj.pyx
@@ -97,6 +97,7 @@ def preprocess_training_data(gold_tuples, label_freq_cutoff=30):
     freqs = {}
     for raw_text, sents in gold_tuples:
         prepro_sents = []
+        cats = sents.pop()
         for (ids, words, tags, heads, labels, iob), ctnts in sents:
             proj_heads, deco_labels = projectivize(heads, labels)
             # set the label to ROOT for each root dependent
@@ -109,6 +110,8 @@ def preprocess_training_data(gold_tuples, label_freq_cutoff=30):
                         freqs[label] = freqs.get(label, 0) + 1
             prepro_sents.append(
                 ((ids, words, tags, proj_heads, deco_labels, iob), ctnts))
+        sents.append(cats)
+        prepro_sents.append(cats)
         preprocessed.append((raw_text, prepro_sents))
     if label_freq_cutoff > 0:
         return _filter_labels(preprocessed, label_freq_cutoff, freqs)
@@ -209,6 +212,7 @@ def _filter_labels(gold_tuples, cutoff, freqs):
     filtered = []
     for raw_text, sents in gold_tuples:
         filtered_sents = []
+        cats = sents.pop()
        for (ids, words, tags, heads, labels, iob), ctnts in sents:
             filtered_labels = []
             for label in labels:
@@ -218,5 +222,7 @@ def _filter_labels(gold_tuples, cutoff, freqs):
                     filtered_labels.append(label)
             filtered_sents.append(
                 ((ids, words, tags, heads, filtered_labels, iob), ctnts))
+        sents.append(cats)
+        filtered_sents.append(cats)
         filtered.append((raw_text, filtered_sents))
     return filtered
diff --git a/spacy/tests/regression/test_issue4402.py b/spacy/tests/regression/test_issue4402.py
new file mode 100644
index 000000000..d213253ed
--- /dev/null
+++ b/spacy/tests/regression/test_issue4402.py
@@ -0,0 +1,97 @@
+# coding: utf8
+from __future__ import unicode_literals
+
+import srsly
+from spacy.gold import GoldCorpus, json_to_tuple
+
+from spacy.lang.en import English
+from spacy.tests.util import make_tempdir
+
+
+def test_issue4402():
+    nlp = English()
+    with make_tempdir() as tmpdir:
+        print("temp", tmpdir)
+        json_path = tmpdir / "test4402.json"
+        srsly.write_json(json_path, json_data)
+
+        corpus = GoldCorpus(str(json_path), str(json_path))
+
+        train_docs = list(corpus.train_docs(nlp, gold_preproc=True, max_length=0))
+        # assert that the data got split into 4 sentences
+        assert len(train_docs) == 4
+
+
+json_data = [
+    {
+        "id": 0,
+        "paragraphs": [
+            {
+                "raw": "How should I cook bacon in an oven?\nI've heard of people cooking bacon in an oven.",
+                "sentences": [
+                    {
+                        "tokens": [
+                            {"id": 0, "orth": "How", "ner": "O"},
+                            {"id": 1, "orth": "should", "ner": "O"},
+                            {"id": 2, "orth": "I", "ner": "O"},
+                            {"id": 3, "orth": "cook", "ner": "O"},
+                            {"id": 4, "orth": "bacon", "ner": "O"},
+                            {"id": 5, "orth": "in", "ner": "O"},
+                            {"id": 6, "orth": "an", "ner": "O"},
+                            {"id": 7, "orth": "oven", "ner": "O"},
+                            {"id": 8, "orth": "?", "ner": "O"},
+                        ],
+                        "brackets": [],
+                    },
+                    {
+                        "tokens": [
+                            {"id": 9, "orth": "\n", "ner": "O"},
+                            {"id": 10, "orth": "I", "ner": "O"},
+                            {"id": 11, "orth": "'ve", "ner": "O"},
+                            {"id": 12, "orth": "heard", "ner": "O"},
+                            {"id": 13, "orth": "of", "ner": "O"},
+                            {"id": 14, "orth": "people", "ner": "O"},
+                            {"id": 15, "orth": "cooking", "ner": "O"},
+                            {"id": 16, "orth": "bacon", "ner": "O"},
+                            {"id": 17, "orth": "in", "ner": "O"},
+                            {"id": 18, "orth": "an", "ner": "O"},
+                            {"id": 19, "orth": "oven", "ner": "O"},
+                            {"id": 20, "orth": ".", "ner": "O"},
+                        ],
+                        "brackets": [],
+                    },
+                ],
+                "cats": [
+                    {"label": "baking", "value": 1.0},
+                    {"label": "not_baking", "value": 0.0},
+                ],
+            },
+            {
+                "raw": "What is the difference between white and brown eggs?\n",
+                "sentences": [
+                    {
+                        "tokens": [
+                            {"id": 0, "orth": "What", "ner": "O"},
+                            {"id": 1, "orth": "is", "ner": "O"},
+                            {"id": 2, "orth": "the", "ner": "O"},
+                            {"id": 3, "orth": "difference", "ner": "O"},
+                            {"id": 4, "orth": "between", "ner": "O"},
+                            {"id": 5, "orth": "white", "ner": "O"},
+                            {"id": 6, "orth": "and", "ner": "O"},
+                            {"id": 7, "orth": "brown", "ner": "O"},
+                            {"id": 8, "orth": "eggs", "ner": "O"},
+                            {"id": 9, "orth": "?", "ner": "O"},
+                        ],
+                        "brackets": [],
+                    },
+                    {"tokens": [{"id": 10, "orth": "\n", "ner": "O"}], "brackets": []},
+                ],
+                "cats": [
+                    {"label": "baking", "value": 0.0},
+                    {"label": "not_baking", "value": 1.0},
+                ],
+            },
+        ],
+    }
]