diff --git a/spacy/tokens/doc.pyx b/spacy/tokens/doc.pyx
index 8a7d12555..eab6c044e 100644
--- a/spacy/tokens/doc.pyx
+++ b/spacy/tokens/doc.pyx
@@ -21,6 +21,7 @@ from ..lexeme cimport Lexeme
 from .spans cimport Span
 from .token cimport Token
 from ..serialize.bits cimport BitArray
+from ..util import normalize_slice
 
 
 DEF PADDING = 5
@@ -81,20 +82,14 @@ cdef class Doc:
         self._vector = None
 
     def __getitem__(self, object i):
-        """Get a token.
+        """Get a Token or a Span from the Doc.
 
         Returns:
-            token (Token):
+            token (Token) or span (Span):
         """
         if isinstance(i, slice):
-            if i.step is not None:
-                raise ValueError("Stepped slices not supported in Span objects."
-                                 "Try: list(doc)[start:stop:step] instead.")
-            if i.start is None:
-                i = slice(0, i.stop)
-            if i.stop is None:
-                i = slice(i.start, len(self))
-            return Span(self, i.start, i.stop, label=0)
+            start, stop = normalize_slice(len(self), i.start, i.stop, i.step)
+            return Span(self, start, stop, label=0)
 
         if i < 0:
             i = self.length + i
diff --git a/spacy/tokens/spans.pyx b/spacy/tokens/spans.pyx
index c39f8976c..e8d2f2e59 100644
--- a/spacy/tokens/spans.pyx
+++ b/spacy/tokens/spans.pyx
@@ -9,16 +9,16 @@ from ..structs cimport TokenC, LexemeC
 from ..typedefs cimport flags_t, attr_t
 from ..attrs cimport attr_id_t
 from ..parts_of_speech cimport univ_pos_t
+from ..util import normalize_slice
 
 
 cdef class Span:
     """A slice from a Doc object."""
     def __cinit__(self, Doc tokens, int start, int end, int label=0, vector=None,
                   vector_norm=None):
-        if start < 0:
-            start = tokens.length - start
-        if end < 0:
-            end = tokens.length - end
+        if not (0 <= start <= end <= len(tokens)):
+            raise IndexError("Span requires 0 <= start <= end <= len(tokens)")
+
         self.doc = tokens
         self.start = start
         self.end = end
@@ -46,7 +46,13 @@ cdef class Span:
             return 0
         return self.end - self.start
 
-    def __getitem__(self, int i):
+    def __getitem__(self, object i):
+        if isinstance(i, slice):
+            start, end = normalize_slice(len(self), i.start, i.stop, i.step)
+            start += self.start
+            end += self.start
+            return Span(self.doc, start, end)
+
         if i < 0:
             return self.doc[self.end + i]
         else:
diff --git a/spacy/util.py b/spacy/util.py
index 93a67c66e..849a3e219 100644
--- a/spacy/util.py
+++ b/spacy/util.py
@@ -7,6 +7,28 @@ from .attrs import TAG, HEAD, DEP, ENT_IOB, ENT_TYPE
 
 DATA_DIR = path.join(path.dirname(__file__), '..', 'data')
 
 
+def normalize_slice(length, start, stop, step=None):
+    """Resolve None endpoints and negative indices of a (start, stop) slice
+    over a sequence of the given length, clamping the result to [0, length]."""
+    if not (step is None or step == 1):
+        raise ValueError("Stepped slices not supported in Span objects. "
+                         "Try: list(tokens)[start:stop:step] instead.")
+    if start is None:
+        start = 0
+    elif start < 0:
+        start += length
+    start = min(length, max(0, start))
+
+    if stop is None:
+        stop = length
+    elif stop < 0:
+        stop += length
+    stop = min(length, max(start, stop))
+
+    assert 0 <= start <= stop <= length
+    return start, stop
+
+
 def utf8open(loc, mode='r'):
     return io.open(loc, mode, encoding='utf8')
diff --git a/tests/tokens/test_tokens_api.py b/tests/tokens/test_tokens_api.py
index e1238373f..675f00235 100644
--- a/tests/tokens/test_tokens_api.py
+++ b/tests/tokens/test_tokens_api.py
@@ -12,6 +12,72 @@ def test_getitem(EN):
     with pytest.raises(IndexError):
         tokens[len(tokens)]
 
+    def to_str(span):
+        return '/'.join(token.orth_ for token in span)
+
+    span = tokens[1:1]
+    assert not to_str(span)
+    span = tokens[1:4]
+    assert to_str(span) == 'it/back/!'
+    span = tokens[1:4:1]
+    assert to_str(span) == 'it/back/!'
+    with pytest.raises(ValueError):
+        tokens[1:4:2]
+    with pytest.raises(ValueError):
+        tokens[1:4:-1]
+
+    span = tokens[-3:6]
+    assert to_str(span) == 'He/pleaded'
+    span = tokens[4:-1]
+    assert to_str(span) == 'He/pleaded'
+    span = tokens[-5:-3]
+    assert to_str(span) == 'back/!'
+    span = tokens[5:4]
+    assert span.start == span.end == 5 and not to_str(span)
+    span = tokens[4:-3]
+    assert span.start == span.end == 4 and not to_str(span)
+
+    span = tokens[:]
+    assert to_str(span) == 'Give/it/back/!/He/pleaded/.'
+    span = tokens[4:]
+    assert to_str(span) == 'He/pleaded/.'
+    span = tokens[:4]
+    assert to_str(span) == 'Give/it/back/!'
+    span = tokens[:-3]
+    assert to_str(span) == 'Give/it/back/!'
+    span = tokens[-3:]
+    assert to_str(span) == 'He/pleaded/.'
+
+    span = tokens[4:50]
+    assert to_str(span) == 'He/pleaded/.'
+    span = tokens[-50:4]
+    assert to_str(span) == 'Give/it/back/!'
+    span = tokens[-50:-40]
+    assert span.start == span.end == 0 and not to_str(span)
+    span = tokens[40:50]
+    assert span.start == span.end == 7 and not to_str(span)
+
+    span = tokens[1:4]
+    assert span[0].orth_ == 'it'
+    subspan = span[:]
+    assert to_str(subspan) == 'it/back/!'
+    subspan = span[:2]
+    assert to_str(subspan) == 'it/back'
+    subspan = span[1:]
+    assert to_str(subspan) == 'back/!'
+    subspan = span[:-1]
+    assert to_str(subspan) == 'it/back'
+    subspan = span[-2:]
+    assert to_str(subspan) == 'back/!'
+    subspan = span[1:2]
+    assert to_str(subspan) == 'back'
+    subspan = span[-2:-1]
+    assert to_str(subspan) == 'back'
+    subspan = span[-50:50]
+    assert to_str(subspan) == 'it/back/!'
+    subspan = span[50:-50]
+    assert subspan.start == subspan.end == 4 and not to_str(subspan)
+
 
 @pytest.mark.models
 def test_serialize(EN):
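
As a quick sanity check of the clamping rules added in the spacy/util.py hunk, the sketch below exercises normalize_slice directly; the expected tuples are derived from the function's own logic above, and it runs without a loaded model once this patch is applied:

    from spacy.util import normalize_slice

    # None endpoints resolve to the bounds of the sequence.
    assert normalize_slice(7, None, None) == (0, 7)
    # Negative indices count back from the end, as in ordinary Python slicing.
    assert normalize_slice(7, -3, None) == (4, 7)
    # Out-of-range bounds are clamped to [0, length] instead of raising.
    assert normalize_slice(7, -50, 50) == (0, 7)
    # A stop that falls before start collapses to an empty range at start.
    assert normalize_slice(7, 5, 4) == (5, 5)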