diff --git a/bin/init_model.py b/bin/init_model.py index d40e7813d..2591ada50 100644 --- a/bin/init_model.py +++ b/bin/init_model.py @@ -18,6 +18,7 @@ Requires: from __future__ import unicode_literals from ast import literal_eval +import math import plac from pathlib import Path @@ -91,15 +92,15 @@ def _read_probs(loc): def _read_freqs(loc): counts = PreshCounter() total = 0 - for line in open(loc): + for i, line in enumerate(loc.open()): freq, doc_freq, key = line.split('\t', 2) freq = int(freq) - counts[hash_string(key)] = freq + counts.inc(i+1, freq) total += freq counts.smooth() log_total = math.log(total) probs = {} - for line in open(loc): + for line in loc.open(): freq, doc_freq, key = line.split('\t', 2) if int(doc_freq) >= 2 and int(freq) >= 5 and len(key) < 200: word = literal_eval(key)