diff --git a/bin/parser/train.py b/bin/parser/train.py index 628caf515..dc6875733 100755 --- a/bin/parser/train.py +++ b/bin/parser/train.py @@ -26,8 +26,21 @@ from spacy.syntax.conll import GoldParse from spacy.scorer import Scorer +def add_noise(c, noise_level): + if random.random() >= noise_level: + return c + elif c == ' ': + return '\n' + elif c == '\n': + return ' ' + elif c in ['.', "'", "!", "?"]: + return '' + else: + return c.lower() + + def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic', seed=0, - gold_preproc=False, n_sents=0): + gold_preproc=False, n_sents=0, corruption_level=0): dep_model_dir = path.join(model_dir, 'deps') pos_model_dir = path.join(model_dir, 'pos') ner_model_dir = path.join(model_dir, 'ner') @@ -55,15 +68,13 @@ def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic', seed=0 print "Itn.\tUAS\tNER F.\tTag %\tToken %" for itn in range(n_iter): scorer = Scorer() - for raw_text, segmented_text, annot_tuples, ctnt in gold_tuples: + for raw_text, annot_tuples, ctnt in gold_tuples: + raw_text = ''.join(add_noise(c, corruption_level) for c in raw_text) tokens = nlp(raw_text, merge_mwes=False) gold = GoldParse(tokens, annot_tuples) scorer.score(tokens, gold, verbose=False) - - if gold_preproc: - sents = [nlp.tokenizer.tokens_from_list(s) for s in segmented_text] - else: - sents = [nlp.tokenizer(raw_text)] + assert not gold_preproc + sents = [nlp.tokenizer(raw_text)] for tokens in sents: gold = GoldParse(tokens, annot_tuples) nlp.tagger(tokens) @@ -90,7 +101,7 @@ def evaluate(Language, gold_tuples, model_dir, gold_preproc=False, verbose=True) assert not gold_preproc nlp = Language(data_dir=model_dir) scorer = Scorer() - for raw_text, segmented_text, annot_tuples, brackets in gold_tuples: + for raw_text, annot_tuples, brackets in gold_tuples: tokens = nlp(raw_text, merge_mwes=False) gold = GoldParse(tokens, annot_tuples) scorer.score(tokens, gold, verbose=verbose) @@ -111,7 +122,7 @@ def write_parses(Language, dev_loc, model_dir, out_loc): return scorer -def get_sents(json_dir, section): +def get_sents(json_loc): if path.exists(path.join(json_dir, section + '.json')): for sent in read_json_file(path.join(json_dir, section + '.json')): yield sent @@ -131,21 +142,24 @@ def get_sents(json_dir, section): @plac.annotations( - json_dir=("Annotated JSON files directory",), + train_loc=("Location of training json file"), + dev_loc=("Location of development json file"), + corruption_level=("Amount of noise to add to training data", "option", "c", float), model_dir=("Location of output model directory",), out_loc=("Out location", "option", "o", str), n_sents=("Number of training sentences", "option", "n", int), verbose=("Verbose error reporting", "flag", "v", bool), debug=("Debug mode", "flag", "d", bool) ) -def main(json_dir, model_dir, n_sents=0, out_loc="", verbose=False, - debug=False): - train(English, list(get_sents(json_dir, 'train')), model_dir, +def main(train_loc, dev_loc, model_dir, n_sents=0, out_loc="", verbose=False, + debug=False, corruption_level=0.0): + train(English, read_json_file(train_loc), model_dir, feat_set='basic' if not debug else 'debug', - gold_preproc=False, n_sents=n_sents) + gold_preproc=False, n_sents=n_sents, + corruption_level=corruption_level) if out_loc: write_parses(English, dev_loc, model_dir, out_loc) - scorer = evaluate(English, list(get_sents(json_dir, 'dev')), + scorer = evaluate(English, read_json_file(dev_loc), model_dir, gold_preproc=False, verbose=verbose) print 'TOK', 100-scorer.token_acc print 'POS', scorer.tags_acc diff --git a/bin/prepare_treebank.py b/bin/prepare_treebank.py index 8b23f3670..c2f765fa6 100644 --- a/bin/prepare_treebank.py +++ b/bin/prepare_treebank.py @@ -34,44 +34,30 @@ def _iter_raw_files(raw_loc): yield f -def _get_word_indices(raw_sent, word_idx, offset): - indices = {} - for piece in raw_sent.split(''): - for match in re.finditer(r'\S+', piece): - indices[word_idx] = offset + match.start() - word_idx += 1 - offset += len(piece) - return indices, word_idx, offset + 1 - - def format_doc(section, filename, raw_paras, ptb_loc, dep_loc): ptb_sents = read_ptb.split(open(ptb_loc).read()) dep_sents = read_conll.split(open(dep_loc).read()) assert len(ptb_sents) == len(dep_sents) - word_idx = 0 i = 0 doc = {'id': filename, 'paragraphs': []} for raw_sents in raw_paras: - para = {'raw': ' '.join(sent.replace('', '') for sent in raw_sents), - 'segmented': ''.join(raw_sents), - 'sents': [], - 'tokens': [], - 'brackets': []} + para = { + 'raw': ' '.join(sent.replace('', '') for sent in raw_sents), + 'sents': [], + 'tokens': [], + 'brackets': []} offset = 0 for raw_sent in raw_sents: - words = raw_sent.replace('', ' ').split() - para['sents'].append(offset) _, brackets = read_ptb.parse(ptb_sents[i], strip_bad_periods=True) _, annot = read_conll.parse(dep_sents[i], strip_bad_periods=True) - indices, word_idx, offset = _get_word_indices(raw_sent, 0, offset) - for j, token in enumerate(annot): + for token_id, token in enumerate(annot): try: - head = indices[token['head']] if token['head'] != -1 else -1 + head = (token['head'] + offset) if token['head'] != -1 else -1 para['tokens'].append({ - 'start': indices[token['id']], - 'orth': words[j], + 'id': offset + token_id, + 'orth': token['word'], 'tag': token['tag'], 'head': head, 'dep': token['dep']}) @@ -80,9 +66,11 @@ def format_doc(section, filename, raw_paras, ptb_loc, dep_loc): for label, start, end in brackets: if start != end: para['brackets'].append({'label': label, - 'start': indices[start], - 'end': indices[end-1]}) + 'start': start + offset, + 'end': (end-1) + offset}) i += 1 + offset += len(annot) + para['sents'].append(offset) doc['paragraphs'].append(para) return doc diff --git a/setup.py b/setup.py index ff36b4f3a..837d8923f 100644 --- a/setup.py +++ b/setup.py @@ -147,7 +147,7 @@ def main(modules, is_pypy): MOD_NAMES = ['spacy.parts_of_speech', 'spacy.strings', 'spacy.lexeme', 'spacy.vocab', 'spacy.tokens', 'spacy.spans', - 'spacy.morphology', + 'spacy.morphology', 'spacy.munge.alignment', 'spacy._ml', 'spacy.tokenizer', 'spacy.en.attrs', 'spacy.en.pos', 'spacy.syntax.parser', 'spacy.syntax._state', 'spacy.syntax.transition_system',