From ea4a41c8fb573e3034daacdae117f39b18e82842 Mon Sep 17 00:00:00 2001 From: ines Date: Fri, 27 Oct 2017 14:39:09 +0200 Subject: [PATCH] Tidy up util and helpers --- spacy/compat.py | 12 ++--- spacy/deprecated.py | 9 ++-- spacy/glossary.py | 1 - spacy/tokens/printers.py | 9 ++-- spacy/tokens/underscore.py | 4 ++ spacy/util.py | 99 ++++++++++++++++++++------------------ 6 files changed, 71 insertions(+), 63 deletions(-) diff --git a/spacy/compat.py b/spacy/compat.py index 81243ce1b..260c956fb 100644 --- a/spacy/compat.py +++ b/spacy/compat.py @@ -87,15 +87,15 @@ def symlink_to(orig, dest): def is_config(python2=None, python3=None, windows=None, linux=None, osx=None): - return ((python2 == None or python2 == is_python2) and - (python3 == None or python3 == is_python3) and - (windows == None or windows == is_windows) and - (linux == None or linux == is_linux) and - (osx == None or osx == is_osx)) + return ((python2 is None or python2 == is_python2) and + (python3 is None or python3 == is_python3) and + (windows is None or windows == is_windows) and + (linux is None or linux == is_linux) and + (osx is None or osx == is_osx)) def normalize_string_keys(old): - '''Given a dictionary, make sure keys are unicode strings, not bytes.''' + """Given a dictionary, make sure keys are unicode strings, not bytes.""" new = {} for key, value in old.items(): if isinstance(key, bytes_): diff --git a/spacy/deprecated.py b/spacy/deprecated.py index ad52bfe24..a1143474a 100644 --- a/spacy/deprecated.py +++ b/spacy/deprecated.py @@ -24,7 +24,7 @@ def depr_model_download(lang): def resolve_load_name(name, **overrides): - """Resolve model loading if deprecated path kwarg is specified in overrides. + """Resolve model loading if deprecated path kwarg in overrides. name (unicode): Name of model to load. **overrides: Overrides specified in spacy.load(). @@ -32,8 +32,9 @@ def resolve_load_name(name, **overrides): """ if overrides.get('path') not in (None, False, True): name = overrides.get('path') - prints("To load a model from a path, you can now use the first argument. " - "The model meta is used to load the required Language class.", - "OLD: spacy.load('en', path='/some/path')", "NEW: spacy.load('/some/path')", + prints("To load a model from a path, you can now use the first " + "argument. The model meta is used to load the Language class.", + "OLD: spacy.load('en', path='/some/path')", + "NEW: spacy.load('/some/path')", title="Warning: deprecated argument 'path'") return name diff --git a/spacy/glossary.py b/spacy/glossary.py index fd74d85e7..78e61f8a7 100644 --- a/spacy/glossary.py +++ b/spacy/glossary.py @@ -264,7 +264,6 @@ GLOSSARY = { 'nk': 'noun kernel element', 'nmc': 'numerical component', 'oa': 'accusative object', - 'oa': 'second accusative object', 'oc': 'clausal object', 'og': 'genitive object', 'op': 'prepositional object', diff --git a/spacy/tokens/printers.py b/spacy/tokens/printers.py index 4bc7099d7..92b2cd84c 100644 --- a/spacy/tokens/printers.py +++ b/spacy/tokens/printers.py @@ -43,8 +43,8 @@ def POS_tree(root, light=False, flat=False): def parse_tree(doc, light=False, flat=False): - """Makes a copy of the doc, then construct a syntactic parse tree, similar to - the one used in displaCy. Generates the POS tree for all sentences in a doc. + """Make a copy of the doc and construct a syntactic parse tree similar to + displaCy. Generates the POS tree for all sentences in a doc. doc (Doc): The doc for parsing. RETURNS (dict): The parse tree. @@ -66,8 +66,9 @@ def parse_tree(doc, light=False, flat=False): 'NE': '', 'word': 'ate', 'arc': 'ROOT', 'POS_coarse': 'VERB', 'POS_fine': 'VBD', 'lemma': 'eat'} """ - doc_clone = Doc(doc.vocab, words=[w.text for w in doc]) + doc_clone = Doc(doc.vocab, words=[w.text for w in doc]) doc_clone.from_array([HEAD, TAG, DEP, ENT_IOB, ENT_TYPE], doc.to_array([HEAD, TAG, DEP, ENT_IOB, ENT_TYPE])) merge_ents(doc_clone) # merge the entities into single tokens first - return [POS_tree(sent.root, light=light, flat=flat) for sent in doc_clone.sents] + return [POS_tree(sent.root, light=light, flat=flat) + for sent in doc_clone.sents] diff --git a/spacy/tokens/underscore.py b/spacy/tokens/underscore.py index 6e782647b..d80f50685 100644 --- a/spacy/tokens/underscore.py +++ b/spacy/tokens/underscore.py @@ -1,5 +1,9 @@ +# coding: utf8 +from __future__ import unicode_literals + import functools + class Underscore(object): doc_extensions = {} span_extensions = {} diff --git a/spacy/util.py b/spacy/util.py index ca5a40f97..a45d43c47 100644 --- a/spacy/util.py +++ b/spacy/util.py @@ -10,25 +10,27 @@ from pathlib import Path import sys import textwrap import random -import numpy -import io -import dill from collections import OrderedDict from thinc.neural._classes.model import Model import functools +from .symbols import ORTH +from .compat import cupy, CudaStream, path2str, basestring_, input_, unicode_ +from .compat import import_file + import msgpack import msgpack_numpy msgpack_numpy.patch() -import ujson - -from .symbols import ORTH -from .compat import cupy, CudaStream, path2str, basestring_, input_, unicode_ -from .compat import copy_array, normalize_string_keys, getattr_, import_file LANGUAGES = {} _data_path = Path(__file__).parent / 'data' +_PRINT_ENV = False + + +def set_env_log(value): + global _PRINT_ENV + _PRINT_ENV = value def get_lang_class(lang): @@ -38,11 +40,12 @@ def get_lang_class(lang): RETURNS (Language): Language class. """ global LANGUAGES - if not lang in LANGUAGES: + if lang not in LANGUAGES: try: module = importlib.import_module('.lang.%s' % lang, 'spacy') except ImportError: - raise ImportError("Can't import language %s from spacy.lang." %lang) + msg = "Can't import language %s from spacy.lang." + raise ImportError(msg % lang) LANGUAGES[lang] = getattr(module, module.__all__[0]) return LANGUAGES[lang] @@ -100,14 +103,14 @@ def load_model(name, **overrides): data_path = get_data_path() if not data_path or not data_path.exists(): raise IOError("Can't find spaCy data path: %s" % path2str(data_path)) - if isinstance(name, basestring_): - if name in set([d.name for d in data_path.iterdir()]): # in data dir / shortcut + if isinstance(name, basestring_): # in data dir / shortcut + if name in set([d.name for d in data_path.iterdir()]): return load_model_from_link(name, **overrides) - if is_package(name): # installed as package + if is_package(name): # installed as package return load_model_from_package(name, **overrides) - if Path(name).exists(): # path to model data directory + if Path(name).exists(): # path to model data directory return load_model_from_path(Path(name), **overrides) - elif hasattr(name, 'exists'): # Path or Path-like to model data + elif hasattr(name, 'exists'): # Path or Path-like to model data return load_model_from_path(name, **overrides) raise IOError("Can't find model '%s'" % name) @@ -120,7 +123,7 @@ def load_model_from_link(name, **overrides): except AttributeError: raise IOError( "Cant' load '%s'. If you're using a shortcut link, make sure it " - "points to a valid model package (not just a data directory)." % name) + "points to a valid package (not just a data directory)." % name) return cls.load(**overrides) @@ -164,7 +167,8 @@ def load_model_from_init_py(init_file, **overrides): data_dir = '%s_%s-%s' % (meta['lang'], meta['name'], meta['version']) data_path = model_path / data_dir if not model_path.exists(): - raise ValueError("Can't find model directory: %s" % path2str(data_path)) + msg = "Can't find model directory: %s" + raise ValueError(msg % path2str(data_path)) return load_model_from_path(data_path, meta, **overrides) @@ -176,14 +180,16 @@ def get_model_meta(path): """ model_path = ensure_path(path) if not model_path.exists(): - raise ValueError("Can't find model directory: %s" % path2str(model_path)) + msg = "Can't find model directory: %s" + raise ValueError(msg % path2str(model_path)) meta_path = model_path / 'meta.json' if not meta_path.is_file(): raise IOError("Could not read meta.json from %s" % meta_path) meta = read_json(meta_path) for setting in ['lang', 'name', 'version']: if setting not in meta or not meta[setting]: - raise ValueError("No valid '%s' setting found in model meta.json" % setting) + msg = "No valid '%s' setting found in model meta.json" + raise ValueError(msg % setting) return meta @@ -240,7 +246,7 @@ def get_async(stream, numpy_array): return numpy_array else: array = cupy.ndarray(numpy_array.shape, order='C', - dtype=numpy_array.dtype) + dtype=numpy_array.dtype) array.set(numpy_array, stream=stream) return array @@ -274,12 +280,6 @@ def itershuffle(iterable, bufsize=1000): raise StopIteration -_PRINT_ENV = False -def set_env_log(value): - global _PRINT_ENV - _PRINT_ENV = value - - def env_opt(name, default=None): if type(default) is float: type_convert = float @@ -305,17 +305,20 @@ def read_regex(path): path = ensure_path(path) with path.open() as file_: entries = file_.read().split('\n') - expression = '|'.join(['^' + re.escape(piece) for piece in entries if piece.strip()]) + expression = '|'.join(['^' + re.escape(piece) + for piece in entries if piece.strip()]) return re.compile(expression) def compile_prefix_regex(entries): if '(' in entries: # Handle deprecated data - expression = '|'.join(['^' + re.escape(piece) for piece in entries if piece.strip()]) + expression = '|'.join(['^' + re.escape(piece) + for piece in entries if piece.strip()]) return re.compile(expression) else: - expression = '|'.join(['^' + piece for piece in entries if piece.strip()]) + expression = '|'.join(['^' + piece + for piece in entries if piece.strip()]) return re.compile(expression) @@ -359,16 +362,15 @@ def update_exc(base_exceptions, *addition_dicts): exc = dict(base_exceptions) for additions in addition_dicts: for orth, token_attrs in additions.items(): - if not all(isinstance(attr[ORTH], unicode_) for attr in token_attrs): - msg = "Invalid value for ORTH in exception: key='%s', orths='%s'" + if not all(isinstance(attr[ORTH], unicode_) + for attr in token_attrs): + msg = "Invalid ORTH value in exception: key='%s', orths='%s'" raise ValueError(msg % (orth, token_attrs)) described_orth = ''.join(attr[ORTH] for attr in token_attrs) if orth != described_orth: - raise ValueError("Invalid tokenizer exception: ORTH values " - "combined don't match original string. " - "key='%s', orths='%s'" % (orth, described_orth)) - # overlap = set(exc.keys()).intersection(set(additions)) - # assert not overlap, overlap + msg = ("Invalid tokenizer exception: ORTH values combined " + "don't match original string. key='%s', orths='%s'") + raise ValueError(msg % (orth, described_orth)) exc.update(additions) exc = expand_exc(exc, "'", "’") return exc @@ -401,17 +403,15 @@ def normalize_slice(length, start, stop, step=None): raise ValueError("Stepped slices not supported in Span objects." "Try: list(tokens)[start:stop:step] instead.") if start is None: - start = 0 + start = 0 elif start < 0: - start += length + start += length start = min(length, max(0, start)) - if stop is None: - stop = length + stop = length elif stop < 0: - stop += length + stop += length stop = min(length, max(start, stop)) - assert 0 <= start <= stop <= length return start, stop @@ -428,7 +428,7 @@ def compounding(start, stop, compound): >>> assert next(sizes) == 1.5 * 1.5 """ def clip(value): - return max(value, stop) if (start>stop) else min(value, stop) + return max(value, stop) if (start > stop) else min(value, stop) curr = float(start) while True: yield clip(curr) @@ -438,7 +438,7 @@ def compounding(start, stop, compound): def decaying(start, stop, decay): """Yield an infinite series of linearly decaying values.""" def clip(value): - return max(value, stop) if (start>stop) else min(value, stop) + return max(value, stop) if (start > stop) else min(value, stop) nr_upd = 1. while True: yield clip(start * 1./(1. + decay * nr_upd)) @@ -530,17 +530,19 @@ def print_markdown(data, title=None): if isinstance(data, dict): data = list(data.items()) - markdown = ["* **{}:** {}".format(l, unicode_(v)) for l, v in data if not excl_value(v)] + markdown = ["* **{}:** {}".format(l, unicode_(v)) + for l, v in data if not excl_value(v)] if title: print("\n## {}".format(title)) print('\n{}\n'.format('\n'.join(markdown))) def prints(*texts, **kwargs): - """Print formatted message (manual ANSI escape sequences to avoid dependency) + """Print formatted message (manual ANSI escape sequences to avoid + dependency) *texts (unicode): Texts to print. Each argument is rendered as paragraph. - **kwargs: 'title' becomes coloured headline. 'exits'=True performs sys exit. + **kwargs: 'title' becomes coloured headline. exits=True performs sys exit. """ exits = kwargs.get('exits', None) title = kwargs.get('title', None) @@ -570,7 +572,8 @@ def _wrap(text, wrap_max=80, indent=4): def minify_html(html): """Perform a template-specific, rudimentary HTML minification for displaCy. - Disclaimer: NOT a general-purpose solution, only removes indentation/newlines. + Disclaimer: NOT a general-purpose solution, only removes indentation and + newlines. html (unicode): Markup to minify. RETURNS (unicode): "Minified" HTML.