diff --git a/spacy/util.py b/spacy/util.py index 8c9ea319c..2b6f50a6b 100644 --- a/spacy/util.py +++ b/spacy/util.py @@ -2,23 +2,74 @@ import os import io import json import re - -from sputnik import Sputnik +import os.path +from contextlib import contextmanager from .attrs import TAG, HEAD, DEP, ENT_IOB, ENT_TYPE -def get_package(name=None, data_path=None): - if data_path is None: - if os.environ.get('SPACY_DATA'): - data_path = os.environ.get('SPACY_DATA') - else: - data_path = os.path.abspath( - os.path.join(os.path.dirname(__file__), 'data')) +def local_path(subdir): + return os.path.abspath(os.path.join(os.path.dirname(__file__), 'data')) - sputnik = Sputnik('spacy', '0.100.0') # TODO: retrieve version - pool = sputnik.pool(data_path) - return pool.get(name or 'en_default') + +class MockPackage(object): + @classmethod + def create_or_return(cls, me_or_arg): + return me_or_arg if isinstance(me_or_arg, cls) else me_or_arg + + def __init__(self, data_path=None): + if data_path is None: + data_path = local_path('data') + self.name = None + self.data_path = data_path + self._root = self.data_path + + def get(self, key): + pass + + def has_file(self, *path_parts): + return os.path.exists(os.path.join(self._root, *path_parts)) + + def file_path(self, *path_parts, **kwargs): + return os.path.join(self._root, *path_parts) + + def dir_path(self, *path_parts, **kwargs): + return os.path.join(self._root, *path_parts) + + def load_utf8(self, func, *path_parts, **kwargs): + if kwargs.get('require', True): + with io.open(self.file_path(os.path.join(*path_parts)), + mode='r', encoding='utf8') as f: + return func(f) + else: + return None + + @contextmanager + def open(self, path_parts, default=IOError): + if isinstance(default, Exception): + raise default + + # Enter + file_ = io.open(self.file_path(os.path.join(*path_parts)), + mode='r', encoding='utf8') + yield file_ + # Exit + file_.close() + + + +def get_package(name=None, data_path=None): + return MockPackage(data_path) + #if data_path is None: + # if os.environ.get('SPACY_DATA'): + # data_path = os.environ.get('SPACY_DATA') + # else: + # data_path = os.path.abspath( + # os.path.join(os.path.dirname(__file__), 'data')) + + #sputnik = Sputnik('spacy', '0.100.0') # TODO: retrieve version + #pool = sputnik.pool(data_path) + #return pool.get(name or 'en_default') def normalize_slice(length, start, stop, step=None):