mirror of https://github.com/explosion/spaCy.git
Update conllu script
This commit is contained in:
parent
9c8a0f6eba
commit
a26e399f84
|
@ -28,6 +28,8 @@ def prevent_bad_sentences(doc):
|
|||
token.is_sent_start = False
|
||||
elif not token.nbor(-1).is_punct:
|
||||
token.is_sent_start = False
|
||||
elif token.nbor(-1).is_left_punct:
|
||||
token.is_sent_start = False
|
||||
return doc
|
||||
|
||||
|
||||
|
@ -99,7 +101,7 @@ def read_data(nlp, conllu_file, text_file, raw_text=True, oracle_segments=False,
|
|||
# cs is conllu sent, ct is conllu token
|
||||
docs = []
|
||||
golds = []
|
||||
for text, cd in zip(paragraphs, conllu):
|
||||
for doc_id, (text, cd) in enumerate(zip(paragraphs, conllu)):
|
||||
doc_words = []
|
||||
doc_tags = []
|
||||
doc_heads = []
|
||||
|
@ -140,7 +142,7 @@ def read_data(nlp, conllu_file, text_file, raw_text=True, oracle_segments=False,
|
|||
golds.append(GoldParse(docs[-1], words=doc_words, tags=doc_tags,
|
||||
heads=doc_heads, deps=doc_deps,
|
||||
entities=doc_ents))
|
||||
if limit and len(docs) >= limit:
|
||||
if limit and doc_id >= limit:
|
||||
break
|
||||
return docs, golds
|
||||
|
||||
|
@ -188,7 +190,14 @@ def parse_dev_data(nlp, text_loc, conllu_loc, oracle_segments=False,
|
|||
with open(conllu_loc) as conllu_file:
|
||||
docs, golds = read_data(nlp, conllu_file, text_file,
|
||||
oracle_segments=oracle_segments)
|
||||
if not joint_sbd:
|
||||
if joint_sbd:
|
||||
sbd = nlp.create_pipe('sentencizer')
|
||||
for doc in docs:
|
||||
doc = sbd(doc)
|
||||
for sent in doc.sents:
|
||||
sent[0].is_sent_start = True
|
||||
#docs = (prevent_bad_sentences(doc) for doc in docs)
|
||||
else:
|
||||
sbd = nlp.create_pipe('sentencizer')
|
||||
for doc in docs:
|
||||
doc = sbd(doc)
|
||||
|
@ -245,7 +254,8 @@ def main(spacy_model, conllu_train_loc, text_train_loc, conllu_dev_loc, text_dev
|
|||
with open(conllu_train_loc) as conllu_file:
|
||||
with open(text_train_loc) as text_file:
|
||||
docs, golds = read_data(nlp, conllu_file, text_file,
|
||||
oracle_segments=False, raw_text=True)
|
||||
oracle_segments=True, raw_text=True,
|
||||
limit=None)
|
||||
print("Create parser")
|
||||
nlp.add_pipe(nlp.create_pipe('parser'))
|
||||
nlp.add_pipe(nlp.create_pipe('tagger'))
|
||||
|
@ -266,7 +276,7 @@ def main(spacy_model, conllu_train_loc, text_train_loc, conllu_dev_loc, text_dev
|
|||
print("Begin training")
|
||||
# Batch size starts at 1 and grows, so that we make updates quickly
|
||||
# at the beginning of training.
|
||||
batch_sizes = spacy.util.compounding(spacy.util.env_opt('batch_from', 1),
|
||||
batch_sizes = spacy.util.compounding(spacy.util.env_opt('batch_from', 2),
|
||||
spacy.util.env_opt('batch_to', 2),
|
||||
spacy.util.env_opt('batch_compound', 1.001))
|
||||
for i in range(30):
|
||||
|
@ -278,6 +288,7 @@ def main(spacy_model, conllu_train_loc, text_train_loc, conllu_dev_loc, text_dev
|
|||
if not batch:
|
||||
continue
|
||||
batch_docs, batch_gold = zip(*batch)
|
||||
batch_docs = [prevent_bad_sentences(doc) for doc in batch_docs]
|
||||
|
||||
nlp.update(batch_docs, batch_gold, sgd=optimizer,
|
||||
drop=0.2, losses=losses)
|
||||
|
@ -296,6 +307,5 @@ def main(spacy_model, conllu_train_loc, text_train_loc, conllu_dev_loc, text_dev
|
|||
print_conllu(dev_docs, file_)
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
plac.call(main)
|
||||
|
|
Loading…
Reference in New Issue