diff --git a/spacy/pos_util.py b/spacy/pos_util.py
index 667f2d9ac..7d4a36ac6 100644
--- a/spacy/pos_util.py
+++ b/spacy/pos_util.py
@@ -6,56 +6,35 @@
 from .en import EN
 from .pos import Tagger
 
-def realign_tagged(token_rules, tagged_line, sep='/'):
-    words, pos = zip(*[token.rsplit(sep, 1) for token in tagged_line.split()])
-    positions = util.detokenize(token_rules, words)
-    aligned = []
-    for group in positions:
-        w_group = [words[i] for i in group]
-        p_group = [pos[i] for i in group]
-        aligned.append(''.join(w_group) + sep + '_'.join(p_group))
-    return ' '.join(aligned)
-
-
-def read_tagged(detoken_rules, file_, sep='/'):
-    sentences = []
-    for line in file_:
-        if not line.strip():
+def read_gold(file_):
+    paras = file_.read().strip().split('\n\n')
+    golds = []
+    for para in paras:
+        if not para.strip():
             continue
-        line = realign_tagged(detoken_rules, line, sep=sep)
-        tokens, tags = _parse_line(line, sep)
-        assert len(tokens) == len(tags)
-        sentences.append((tokens, tags))
-    return sentences
-
-
-def _parse_line(line, sep):
-    words = []
-    tags = []
-    for token_str in line.split():
-        word, pos = token_str.rsplit(sep, 1)
-        word = word.replace('<SEP>', '')
-        subtokens = EN.tokenize(word)
-        subtags = pos.split('_')
-        while len(subtags) < len(subtokens):
-            subtags.append('NULL')
-        assert len(subtags) == len(subtokens), [t.string for t in subtokens]
-        words.append(word)
-        tags.extend([Tagger.encode_pos(ptb_to_univ(pos)) for pos in subtags])
-    tokens = EN.tokenize(' '.join(words)), tags
-    return tokens
-
-
-def get_tagdict(train_sents):
-    tagdict = {}
-    for tokens, tags in train_sents:
-        for i, tag in enumerate(tags):
-            if tag == 'NULL':
-                continue
-            word = tokens.string(i)
-            tagdict.setdefault(word, {}).setdefault(tag, 0)
-            tagdict[word][tag] += 1
-    return tagdict
+        lines = para.strip().split('\n')
+        raw = lines.pop(0)
+        gold_toks = lines.pop(0)
+        tokens = EN.tokenize(raw)
+        tags = []
+        conll_toks = []
+        for line in lines:
+            pieces = line.split()
+            conll_toks.append((int(pieces[0]), len(pieces[1]), pieces[3]))
+        for i, token in enumerate(tokens):
+            if not conll_toks:
+                tags.append('NULL')
+            elif token.idx == conll_toks[0][0]:
+                tags.append(conll_toks[0][2])
+                conll_toks.pop(0)
+            elif token.idx < conll_toks[0][0]:
+                tags.append('NULL')
+            else:
+                conll_toks.pop(0)
+        assert len(tags) == len(tokens)
+        tags = [Tagger.encode_pos(t) for t in tags]
+        golds.append((tokens, tags))
+    return golds
 
 
 def ptb_to_univ(tag):
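
For reference, a usage sketch of the new read_gold(). The gold-file layout (a raw-text line, a gold-tokenization line, then one row per token whose columns are character offset into the raw line, word, an unused field, and POS tag) and the tag strings are inferred from the parsing code above, not documented by the patch itself:

from io import StringIO

from spacy import pos_util

# One "paragraph" per sentence; paragraphs are separated by blank lines.
# Offsets in column 0 must match token.idx from EN.tokenize(raw); the
# tag strings and the '_' placeholder column are assumptions.
example = u"""The cat sat.
The cat sat .
0 The _ DET
4 cat _ NOUN
8 sat _ VERB
11 . _ PUNCT"""

tokens, tags = pos_util.read_gold(StringIO(example))[0]
# tags holds one Tagger.encode_pos() id per token from EN.tokenize();
# tokens the gold rows do not cover are tagged 'NULL'.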