mirror of https://github.com/explosion/spaCy.git
parent
d8573ee715
commit
3f52e12335
|
@ -13,23 +13,21 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class Corpus(object):
|
||||
def __init__(self, directory, min_freq=10):
|
||||
def __init__(self, directory, nlp):
|
||||
self.directory = directory
|
||||
self.counts = PreshCounter()
|
||||
self.strings = {}
|
||||
self.min_freq = min_freq
|
||||
|
||||
def count_doc(self, doc):
|
||||
# Get counts for this document
|
||||
for word in doc:
|
||||
self.counts.inc(word.orth, 1)
|
||||
return len(doc)
|
||||
self.nlp = nlp
|
||||
|
||||
def __iter__(self):
|
||||
for text_loc in iter_dir(self.directory):
|
||||
with text_loc.open("r", encoding="utf-8") as file_:
|
||||
text = file_.read()
|
||||
yield text
|
||||
|
||||
# This is to keep the input to the blank model (which doesn't
|
||||
# sentencize) from being too long. It works particularly well with
|
||||
# the output of [WikiExtractor](https://github.com/attardi/wikiextractor)
|
||||
paragraphs = text.split('\n\n')
|
||||
for par in paragraphs:
|
||||
yield [word.orth_ for word in self.nlp(par)]
|
||||
|
||||
|
||||
def iter_dir(loc):
|
||||
|
@ -62,12 +60,15 @@ def main(
|
|||
window=5,
|
||||
size=128,
|
||||
min_count=10,
|
||||
nr_iter=2,
|
||||
nr_iter=5,
|
||||
):
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s : %(levelname)s : %(message)s", level=logging.INFO
|
||||
)
|
||||
nlp = spacy.blank(lang)
|
||||
corpus = Corpus(in_dir, nlp)
|
||||
model = Word2Vec(
|
||||
sentences=corpus,
|
||||
size=size,
|
||||
window=window,
|
||||
min_count=min_count,
|
||||
|
@ -75,33 +76,7 @@ def main(
|
|||
sample=1e-5,
|
||||
negative=negative,
|
||||
)
|
||||
nlp = spacy.blank(lang)
|
||||
corpus = Corpus(in_dir)
|
||||
total_words = 0
|
||||
total_sents = 0
|
||||
for text_no, text_loc in enumerate(iter_dir(corpus.directory)):
|
||||
with text_loc.open("r", encoding="utf-8") as file_:
|
||||
text = file_.read()
|
||||
total_sents += text.count("\n")
|
||||
doc = nlp(text)
|
||||
total_words += corpus.count_doc(doc)
|
||||
logger.info(
|
||||
"PROGRESS: at batch #%i, processed %i words, keeping %i word types",
|
||||
text_no,
|
||||
total_words,
|
||||
len(corpus.strings),
|
||||
)
|
||||
model.corpus_count = total_sents
|
||||
model.raw_vocab = defaultdict(int)
|
||||
for orth, freq in corpus.counts:
|
||||
if freq >= min_count:
|
||||
model.raw_vocab[nlp.vocab.strings[orth]] = freq
|
||||
model.scale_vocab()
|
||||
model.finalize_vocab()
|
||||
model.iter = nr_iter
|
||||
model.train(corpus)
|
||||
model.save(out_loc)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
plac.call(main)
|
||||
|
|
Loading…
Reference in New Issue