diff --git a/spacy/errors.py b/spacy/errors.py
index b60fe690a..a557be2e8 100644
--- a/spacy/errors.py
+++ b/spacy/errors.py
@@ -245,6 +245,8 @@ class Errors(object):
             "the meta.json. Vector names are required to avoid issue #1660.")
     E093 = ("token.ent_iob values make invalid sequence: I without B\n{seq}")
     E094 = ("Error reading line {line_num} in vectors file {loc}.")
+    E095 = ("Can't write to frozen dictionary. This is likely an internal "
+            "error. Are you writing to a default function argument?")


 @add_codes
diff --git a/spacy/tests/doc/test_doc_api.py b/spacy/tests/doc/test_doc_api.py
index 06f6a3d30..d9db0916b 100644
--- a/spacy/tests/doc/test_doc_api.py
+++ b/spacy/tests/doc/test_doc_api.py
@@ -4,6 +4,7 @@ from __future__ import unicode_literals
 from ..util import get_doc
 from ...tokens import Doc
 from ...vocab import Vocab
+from ...attrs import LEMMA

 import pytest
 import numpy
@@ -178,6 +179,26 @@ def test_doc_api_merge_hang(en_tokenizer):
     doc.merge(8, 32, tag='', lemma='', ent_type='ORG')


+def test_doc_api_retokenizer(en_tokenizer):
+    doc = en_tokenizer("WKRO played songs by the beach boys all night")
+    with doc.retokenize() as retokenizer:
+        retokenizer.merge(doc[4:7])
+    assert len(doc) == 7
+    assert doc[4].text == 'the beach boys'
+
+
+def test_doc_api_retokenizer_attrs(en_tokenizer):
+    doc = en_tokenizer("WKRO played songs by the beach boys all night")
+    # test both string and integer attributes and values
+    attrs = {LEMMA: 'boys', 'ENT_TYPE': doc.vocab.strings['ORG']}
+    with doc.retokenize() as retokenizer:
+        retokenizer.merge(doc[4:7], attrs=attrs)
+    assert len(doc) == 7
+    assert doc[4].text == 'the beach boys'
+    assert doc[4].lemma_ == 'boys'
+    assert doc[4].ent_type_ == 'ORG'
+
+
 def test_doc_api_sents_empty_string(en_tokenizer):
     doc = en_tokenizer("")
     doc.is_parsed = True
diff --git a/spacy/tokens/_retokenize.pyx b/spacy/tokens/_retokenize.pyx
index 00f724ed6..b405dd000 100644
--- a/spacy/tokens/_retokenize.pyx
+++ b/spacy/tokens/_retokenize.pyx
@@ -11,11 +11,13 @@ from .span cimport Span
 from .token cimport Token
 from ..lexeme cimport Lexeme, EMPTY_LEXEME
 from ..structs cimport LexemeC, TokenC
-from ..attrs cimport *
+from ..attrs cimport TAG
+from ..attrs import intify_attrs
+from ..util import SimpleFrozenDict


 cdef class Retokenizer:
-    '''Helper class for doc.retokenize() context manager.'''
+    """Helper class for doc.retokenize() context manager."""
     cdef Doc doc
     cdef list merges
     cdef list splits
@@ -24,14 +26,18 @@ cdef class Retokenizer:
         self.merges = []
         self.splits = []

-    def merge(self, Span span, attrs=None):
-        '''Mark a span for merging. The attrs will be applied to the resulting
-        token.'''
+    def merge(self, Span span, attrs=SimpleFrozenDict()):
+        """Mark a span for merging. The attrs will be applied to the resulting
+        token.
+        """
+        attrs = intify_attrs(attrs, strings_map=self.doc.vocab.strings)
         self.merges.append((span.start_char, span.end_char, attrs))

-    def split(self, Token token, orths, attrs=None):
-        '''Mark a Token for splitting, into the specified orths. The attrs
-        will be applied to each subtoken.'''
+    def split(self, Token token, orths, attrs=SimpleFrozenDict()):
+        """Mark a Token for splitting, into the specified orths. The attrs
+        will be applied to each subtoken.
+        """
+        attrs = intify_attrs(attrs, strings_map=self.doc.vocab.strings)
         self.splits.append((token.start_char, orths, attrs))

     def __enter__(self):
@@ -125,5 +131,3 @@ def _merge(Doc doc, int start, int end, attributes):
     # Clear the cached Python objects
     # Return the merged Python object
     return doc[start]
-
-
diff --git a/spacy/util.py b/spacy/util.py
index cc3e0d9ee..b80142c38 100644
--- a/spacy/util.py
+++ b/spacy/util.py
@@ -635,3 +635,18 @@ def use_gpu(gpu_id):
 def fix_random_seed(seed=0):
     random.seed(seed)
     numpy.random.seed(seed)
+
+
+class SimpleFrozenDict(dict):
+    """Simplified implementation of a frozen dict, mainly used as default
+    function or method argument (for arguments that should default to empty
+    dictionary). Will raise an error if user or spaCy attempts to add to dict.
+    """
+    def __setitem__(self, key, value):
+        raise NotImplementedError(Errors.E095)
+
+    def pop(self, key, default=None):
+        raise NotImplementedError(Errors.E095)
+
+    def update(self, other):
+        raise NotImplementedError(Errors.E095)