Merge pull request #126 from tomtung/master

Improve slicing support for both Doc and Span
This commit is contained in:
Matthew Honnibal 2015-10-10 14:14:57 +11:00
commit dc393a5f1d
4 changed files with 102 additions and 15 deletions

View File

@@ -21,6 +21,7 @@ from ..lexeme cimport Lexeme
from .spans cimport Span from .spans cimport Span
from .token cimport Token from .token cimport Token
from ..serialize.bits cimport BitArray from ..serialize.bits cimport BitArray
from ..util import normalize_slice
DEF PADDING = 5 DEF PADDING = 5
@@ -81,20 +82,14 @@ cdef class Doc:
self._vector = None self._vector = None
def __getitem__(self, object i): def __getitem__(self, object i):
"""Get a token. """Get a Token or a Span from the Doc.
Returns: Returns:
token (Token): token (Token) or span (Span):
""" """
if isinstance(i, slice): if isinstance(i, slice):
if i.step is not None: start, stop = normalize_slice(len(self), i.start, i.stop, i.step)
raise ValueError("Stepped slices not supported in Span objects." return Span(self, start, stop, label=0)
"Try: list(doc)[start:stop:step] instead.")
if i.start is None:
i = slice(0, i.stop)
if i.stop is None:
i = slice(i.start, len(self))
return Span(self, i.start, i.stop, label=0)
if i < 0: if i < 0:
i = self.length + i i = self.length + i

View File

@@ -9,16 +9,16 @@ from ..structs cimport TokenC, LexemeC
from ..typedefs cimport flags_t, attr_t from ..typedefs cimport flags_t, attr_t
from ..attrs cimport attr_id_t from ..attrs cimport attr_id_t
from ..parts_of_speech cimport univ_pos_t from ..parts_of_speech cimport univ_pos_t
from ..util import normalize_slice
cdef class Span: cdef class Span:
"""A slice from a Doc object.""" """A slice from a Doc object."""
def __cinit__(self, Doc tokens, int start, int end, int label=0, vector=None, def __cinit__(self, Doc tokens, int start, int end, int label=0, vector=None,
vector_norm=None): vector_norm=None):
if start < 0: if not (0 <= start <= end <= len(tokens)):
start = tokens.length - start raise IndexError
if end < 0:
end = tokens.length - end
self.doc = tokens self.doc = tokens
self.start = start self.start = start
self.end = end self.end = end
@@ -46,7 +46,13 @@ cdef class Span:
return 0 return 0
return self.end - self.start return self.end - self.start
def __getitem__(self, int i): def __getitem__(self, object i):
if isinstance(i, slice):
start, end = normalize_slice(len(self), i.start, i.stop, i.step)
start += self.start
end += self.start
return Span(self.doc, start, end)
if i < 0: if i < 0:
return self.doc[self.end + i] return self.doc[self.end + i]
else: else:

View File

@@ -7,6 +7,26 @@ from .attrs import TAG, HEAD, DEP, ENT_IOB, ENT_TYPE
DATA_DIR = path.join(path.dirname(__file__), '..', 'data') DATA_DIR = path.join(path.dirname(__file__), '..', 'data')
def normalize_slice(length, start, stop, step=None):
    """Resolve a slice's bounds against a sequence of the given length.

    Mirrors Python list-slicing semantics: negative indices count from
    the end, and out-of-range bounds are clamped rather than raising,
    so the result is always a valid (possibly empty) range.

    Args:
        length (int): Length of the sequence being sliced.
        start (int or None): Slice start, possibly negative or None.
        stop (int or None): Slice stop, possibly negative or None.
        step (int or None): Slice step; only None or 1 is accepted.

    Returns:
        (int, int): A (start, stop) pair with 0 <= start <= stop <= length.

    Raises:
        ValueError: If step is anything other than None or 1, since a
            Span is a contiguous range and cannot represent strides.
    """
    if not (step is None or step == 1):
        # Bug fix: the original message concatenated without a space,
        # rendering as "...Span objects.Try: ...".
        raise ValueError("Stepped slices not supported in Span objects. "
                         "Try: list(tokens)[start:stop:step] instead.")
    if start is None:
        start = 0
    elif start < 0:
        # Negative start counts from the end, as in normal slicing.
        start += length
    # Clamp into [0, length] so out-of-range slices yield empty spans.
    start = min(length, max(0, start))

    if stop is None:
        stop = length
    elif stop < 0:
        stop += length
    # Clamp into [start, length] so stop never precedes start.
    stop = min(length, max(start, stop))

    assert 0 <= start <= stop <= length
    return start, stop
def utf8open(loc, mode='r'): def utf8open(loc, mode='r'):
return io.open(loc, mode, encoding='utf8') return io.open(loc, mode, encoding='utf8')

View File

@@ -12,6 +12,72 @@ def test_getitem(EN):
with pytest.raises(IndexError): with pytest.raises(IndexError):
tokens[len(tokens)] tokens[len(tokens)]
def to_str(span):
return '/'.join(token.orth_ for token in span)
span = tokens[1:1]
assert not to_str(span)
span = tokens[1:4]
assert to_str(span) == 'it/back/!'
span = tokens[1:4:1]
assert to_str(span) == 'it/back/!'
with pytest.raises(ValueError):
tokens[1:4:2]
with pytest.raises(ValueError):
tokens[1:4:-1]
span = tokens[-3:6]
assert to_str(span) == 'He/pleaded'
span = tokens[4:-1]
assert to_str(span) == 'He/pleaded'
span = tokens[-5:-3]
assert to_str(span) == 'back/!'
span = tokens[5:4]
assert span.start == span.end == 5 and not to_str(span)
span = tokens[4:-3]
assert span.start == span.end == 4 and not to_str(span)
span = tokens[:]
assert to_str(span) == 'Give/it/back/!/He/pleaded/.'
span = tokens[4:]
assert to_str(span) == 'He/pleaded/.'
span = tokens[:4]
assert to_str(span) == 'Give/it/back/!'
span = tokens[:-3]
assert to_str(span) == 'Give/it/back/!'
span = tokens[-3:]
assert to_str(span) == 'He/pleaded/.'
span = tokens[4:50]
assert to_str(span) == 'He/pleaded/.'
span = tokens[-50:4]
assert to_str(span) == 'Give/it/back/!'
span = tokens[-50:-40]
assert span.start == span.end == 0 and not to_str(span)
span = tokens[40:50]
assert span.start == span.end == 7 and not to_str(span)
span = tokens[1:4]
assert span[0].orth_ == 'it'
subspan = span[:]
assert to_str(subspan) == 'it/back/!'
subspan = span[:2]
assert to_str(subspan) == 'it/back'
subspan = span[1:]
assert to_str(subspan) == 'back/!'
subspan = span[:-1]
assert to_str(subspan) == 'it/back'
subspan = span[-2:]
assert to_str(subspan) == 'back/!'
subspan = span[1:2]
assert to_str(subspan) == 'back'
subspan = span[-2:-1]
assert to_str(subspan) == 'back'
subspan = span[-50:50]
assert to_str(subspan) == 'it/back/!'
subspan = span[50:-50]
assert subspan.start == subspan.end == 4 and not to_str(subspan)
@pytest.mark.models @pytest.mark.models
def test_serialize(EN): def test_serialize(EN):