diff --git a/bin/init_model.py b/bin/init_model.py index c29a10107..5f06086df 100644 --- a/bin/init_model.py +++ b/bin/init_model.py @@ -19,6 +19,7 @@ from __future__ import unicode_literals from ast import literal_eval import math +import gzip import plac from pathlib import Path @@ -78,7 +79,7 @@ def _read_clusters(loc): def _read_probs(loc): if not loc.exists(): - print("Warning: Probabilities file not found") + print("Probabilities file not found. Trying freqs.") return {}, 0.0 probs = {} for i, line in enumerate(codecs.open(str(loc), 'r', 'utf8')): @@ -94,7 +95,11 @@ def _read_freqs(loc, max_length=100, min_doc_freq=5, min_freq=100): return {}, 0.0 counts = PreshCounter() total = 0 - for i, line in enumerate(loc.open()): + if str(loc).endswith('gz'): + file_ = gzip.open(str(loc)) + else: + file_ = loc.open() + for i, line in enumerate(file_): freq, doc_freq, key = line.split('\t', 2) freq = int(freq) counts.inc(i+1, freq)