spaCy/bin/get_freqs.py

94 lines
2.7 KiB
Python
Raw Normal View History

2015-07-14 00:31:32 +00:00
#!/usr/bin/env python
2015-10-15 17:20:35 +00:00
from __future__ import unicode_literals, print_function
2015-07-14 00:31:32 +00:00
import plac
import joblib
from os import path
import os
import bz2
import ujson
from preshed.counter import PreshCounter
from joblib import Parallel, delayed
2015-10-19 01:56:18 +00:00
import io
2015-07-14 00:31:32 +00:00
2015-10-15 17:33:49 +00:00
from spacy.en import English
2015-07-14 00:31:32 +00:00
from spacy.strings import StringStore
2015-10-15 17:20:35 +00:00
from spacy.attrs import ORTH
2015-10-15 17:24:08 +00:00
from spacy.tokenizer import Tokenizer
from spacy.vocab import Vocab
2015-07-14 00:31:32 +00:00
def iter_comments(loc):
    """Lazily yield the JSON-decoded value of every line in a bz2 file.

    loc: path of a bz2-compressed file holding one JSON document per line.
    Lines are decoded with ``ujson.loads`` one at a time as the generator is
    consumed, so the decompressed file is never held in memory at once.
    """
    with bz2.BZ2File(loc) as compressed:
        records = (ujson.loads(raw_line) for raw_line in compressed)
        for record in records:
            yield record
def count_freqs(input_loc, output_loc):
    """Count token-string frequencies in one compressed comment dump.

    Every record yielded by ``iter_comments(input_loc)`` has its ``'body'``
    field tokenized with the tokenizer loaded from English's default data
    directory; counts are keyed by each token's ORTH id. The result is
    written to ``output_loc`` (UTF-8) as one ``<freq><TAB><string>`` line per
    distinct string. Strings consisting only of whitespace are not written.

    input_loc: path of a bz2 JSON-lines dump (see iter_comments).
    output_loc: path of the frequency file to create/overwrite.
    """
    print(output_loc)
    # get_lex_attr=None: only the strings are needed here, not lexeme features.
    vocab = English.default_vocab(get_lex_attr=None)
    tokenizer = Tokenizer.from_dir(
        vocab, path.join(English.default_data_dir(), 'tokenizer'))

    counts = PreshCounter()
    for json_comment in iter_comments(input_loc):
        doc = tokenizer(json_comment['body'])
        doc.count_by(ORTH, counts=counts)

    # The encoding must be given by keyword: the third positional parameter
    # of io.open() is ``buffering`` (an int), so io.open(loc, 'w', 'utf8')
    # raises TypeError before anything is written.
    with io.open(output_loc, 'w', encoding='utf8') as file_:
        for orth, freq in counts:
            string = tokenizer.vocab.strings[orth]
            if not string.isspace():
                file_.write('%d\t%s\n' % (freq, string))
2015-07-14 00:31:32 +00:00
def parallelize(func, iterator, n_jobs):
    """Run ``func(*args)`` for every argument tuple produced by ``iterator``.

    The calls are dispatched through joblib's ``Parallel`` with ``n_jobs``
    workers. Any values returned by ``func`` are discarded; this function
    itself returns None.
    """
    calls = (delayed(func)(*args) for args in iterator)
    pool = Parallel(n_jobs=n_jobs)
    pool(calls)
def merge_counts(locs, out_loc):
    """Sum several frequency files into a single frequency file.

    locs: paths of UTF-8 files whose lines have the form ``<freq><TAB><word>``
        (the format written by count_freqs). Blank lines are ignored.
    out_loc: path of the merged file to create/overwrite. It uses the same
        line format, with counts for identical words added together, and is
        ordered from most to least frequent.

    Raises ValueError if a non-blank line has no tab or a non-integer count.
    """
    from collections import Counter

    totals = Counter()
    for loc in locs:
        with io.open(loc, 'r', encoding='utf8') as file_:
            for line in file_:
                # Remove only the line terminator. str.strip() would also
                # remove whitespace-class characters that are part of the
                # word (e.g. a token ending in U+001F), silently merging
                # distinct tokens, and would reduce an empty-word line
                # "5\t" to "5", which cannot be unpacked below.
                line = line.rstrip('\n')
                if not line:
                    continue
                freq, word = line.split('\t', 1)
                totals[word] += int(freq)
    with io.open(out_loc, 'w', encoding='utf8') as file_:
        for word, count in totals.most_common():
            file_.write('%d\t%s\n' % (count, word))
def _freq_filename(input_path):
    """Return the name of the frequency file produced for ``input_path``."""
    # Swap only a trailing ".bz2" for ".freq" (and append ".freq" otherwise).
    # The previous str.replace('bz2', 'freq') also rewrote "bz2" elsewhere in
    # the name, and for names without "bz2" returned the input's own name,
    # so a freqs_dir equal to the input directory would overwrite the dump.
    name = path.basename(input_path)
    if name.endswith('.bz2'):
        name = name[:-len('.bz2')]
    return name + '.freq'


@plac.annotations(
    input_loc=("Location of input file list"),
    freqs_dir=("Directory for frequency files"),
    output_loc=("Location for output file"),
    n_jobs=("Number of workers", "option", "n", int),
    skip_existing=("Skip inputs where an output file exists", "flag", "s", bool),
)
def main(input_loc, freqs_dir, output_loc, n_jobs=2, skip_existing=False):
    """Count token frequencies for a list of dumps, then merge the counts.

    input_loc: text file listing one input dump path per line; blank lines
        are ignored.
    freqs_dir: directory that receives one ``<name>.freq`` file per input.
    output_loc: path of the merged frequency file.
    n_jobs: number of workers passed to parallelize().
    skip_existing: if set, inputs whose frequency file already exists are not
        recounted; their existing file is still included in the merge.
    """
    tasks = []
    outputs = []
    with open(input_loc) as list_file:
        for input_path in list_file:
            input_path = input_path.strip()
            if not input_path:
                continue
            output_path = path.join(freqs_dir, _freq_filename(input_path))
            outputs.append(output_path)
            if skip_existing and path.exists(output_path):
                continue
            tasks.append((input_path, output_path))

    if tasks:
        parallelize(count_freqs, tasks, n_jobs)

    print("Merge")
    merge_counts(outputs, output_loc)
2015-07-14 00:31:32 +00:00
if __name__ == '__main__':
    # plac parses the command line according to main()'s @plac.annotations
    # and then calls main() with the parsed arguments.
    plac.call(main)