Move fix_deprecated_glove_vectors_loading to deprecated.py

This commit is contained in:
ines 2017-03-15 17:33:29 +01:00
parent 42ba740dde
commit 842782c128
2 changed files with 32 additions and 34 deletions

View File

@ -3,6 +3,7 @@ from sputnik.package_list import (PackageNotFoundException,
CompatiblePackageNotFoundException) CompatiblePackageNotFoundException)
import sputnik import sputnik
from pathlib import Path
from . import about from . import about
@ -30,6 +31,7 @@ def get_package_by_name(name=None, via=None):
(lang.__module__)) (lang.__module__))
from . import util
def read_lang_data(package): def read_lang_data(package):
@ -79,4 +81,31 @@ def detokenize(token_rules, words): # Deprecated?
return positions return positions
def fix_glove_vectors_loading(overrides):
    """Special-case hack for loading the GloVe vectors, to support deprecated
    <1.0 stuff. Phase this out once the data is fixed.

    overrides (dict): Keyword arguments for the language class. The keys
        inspected here are 'data_dir', 'path', 'vectors' and 'add_vectors'.
    RETURNS (dict): The same dict object that was passed in. If
        'add_vectors' was not supplied and a matching vectors package is
        found, 'add_vectors' is set to a callable taking a vocab and calling
        its load_vectors_from_bin_loc() on the package's 'vocab/vec.bin'.
    RAISES (ValueError): If 'data_dir' is supplied without 'path'.
    """
    # `basestring` exists only on Python 2. Resolve the string type locally so
    # that passing 'path' as a string does not raise NameError on Python 3.
    try:
        string_types = basestring
    except NameError:
        string_types = str
    if 'data_dir' in overrides and 'path' not in overrides:
        raise ValueError("The argument 'data_dir' has been renamed to 'path'")
    if overrides.get('path') is False:
        # path=False: leave the overrides untouched.
        return overrides
    if overrides.get('path') in (None, True):
        data_path = util.get_data_path()
    else:
        path = overrides['path']
        if isinstance(path, string_types):
            path = Path(path)
        # Vectors packages are looked up next to the given path.
        data_path = path.parent
    vec_path = None
    # An explicitly supplied 'add_vectors' always takes precedence.
    if 'add_vectors' not in overrides:
        if 'vectors' in overrides:
            vec_path = util.match_best_version(overrides['vectors'], None, data_path)
            if vec_path is None:
                return overrides
        else:
            # No vectors requested: try the default GloVe package name.
            vec_path = util.match_best_version('en_glove_cc_300_1m_vectors', None, data_path)
        if vec_path is not None:
            vec_path = vec_path / 'vocab' / 'vec.bin'
    if vec_path is not None:
        overrides['add_vectors'] = lambda vocab: vocab.load_vectors_from_bin_loc(vec_path)
    return overrides

View File

@ -2,15 +2,12 @@
from __future__ import unicode_literals, print_function from __future__ import unicode_literals, print_function
from os import path from os import path
from pathlib import Path
from ..util import match_best_version
from ..util import get_data_path
from ..language import Language from ..language import Language
from ..lemmatizer import Lemmatizer from ..lemmatizer import Lemmatizer
from ..vocab import Vocab from ..vocab import Vocab
from ..tokenizer import Tokenizer from ..tokenizer import Tokenizer
from ..attrs import LANG from ..attrs import LANG
from ..deprecated import fix_glove_vectors_loading
from .language_data import * from .language_data import *
@ -33,34 +30,6 @@ class English(Language):
def __init__(self, **overrides): def __init__(self, **overrides):
# Make a special-case hack for loading the GloVe vectors, to support # Special-case hack for loading the GloVe vectors, to support <1.0
# deprecated <1.0 stuff. Phase this out once the data is fixed. overrides = fix_glove_vectors_loading(overrides)
overrides = _fix_deprecated_glove_vectors_loading(overrides)
Language.__init__(self, **overrides) Language.__init__(self, **overrides)
def _fix_deprecated_glove_vectors_loading(overrides):
if 'data_dir' in overrides and 'path' not in overrides:
raise ValueError("The argument 'data_dir' has been renamed to 'path'")
if overrides.get('path') is False:
return overrides
if overrides.get('path') in (None, True):
data_path = get_data_path()
else:
path = overrides['path']
if isinstance(path, basestring):
path = Path(path)
data_path = path.parent
vec_path = None
if 'add_vectors' not in overrides:
if 'vectors' in overrides:
vec_path = match_best_version(overrides['vectors'], None, data_path)
if vec_path is None:
return overrides
else:
vec_path = match_best_version('en_glove_cc_300_1m_vectors', None, data_path)
if vec_path is not None:
vec_path = vec_path / 'vocab' / 'vec.bin'
if vec_path is not None:
overrides['add_vectors'] = lambda vocab: vocab.load_vectors_from_bin_loc(vec_path)
return overrides