diff --git a/setup.py b/setup.py
index 14f8486ca..fb659bcb0 100755
--- a/setup.py
+++ b/setup.py
@@ -55,6 +55,8 @@ MOD_NAMES = [
"spacy.tokens.doc",
"spacy.tokens.span",
"spacy.tokens.token",
+ "spacy.tokens.span_group",
+ "spacy.tokens.graph",
"spacy.tokens.morphanalysis",
"spacy.tokens._retokenize",
"spacy.matcher.matcher",
@@ -68,7 +70,7 @@ COMPILE_OPTIONS = {
"mingw32": ["-O2", "-Wno-strict-prototypes", "-Wno-unused-function"],
"other": ["-O2", "-Wno-strict-prototypes", "-Wno-unused-function"],
}
-LINK_OPTIONS = {"msvc": [], "mingw32": [], "other": []}
+LINK_OPTIONS = {"msvc": ["-std=c++11"], "mingw32": ["-std=c++11"], "other": []}
COMPILER_DIRECTIVES = {
"language_level": -3,
"embedsignature": True,
@@ -201,7 +203,7 @@ def setup_package():
ext_modules = []
for name in MOD_NAMES:
mod_path = name.replace(".", "/") + ".pyx"
- ext = Extension(name, [mod_path], language="c++")
+ ext = Extension(name, [mod_path], language="c++", extra_compile_args=["-std=c++11"])
ext_modules.append(ext)
print("Cythonizing sources")
ext_modules = cythonize(ext_modules, compiler_directives=COMPILER_DIRECTIVES)
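A minimal smoke test for the build changes above (an editor's sketch, not part of the patch), assuming the branch has been compiled, e.g. with `pip install -e .`:

```python
# Verify that the two new Cython extension modules compiled and import cleanly.
import importlib

for name in ("spacy.tokens.span_group", "spacy.tokens.graph"):
    module = importlib.import_module(name)
    print(f"imported {module.__name__} from {module.__file__}")
```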
diff --git a/spacy/structs.pxd b/spacy/structs.pxd
index 4a51bc9e0..86d5b67ed 100644
--- a/spacy/structs.pxd
+++ b/spacy/structs.pxd
@@ -1,5 +1,7 @@
from libc.stdint cimport uint8_t, uint32_t, int32_t, uint64_t
from libcpp.vector cimport vector
+from libcpp.unordered_set cimport unordered_set
+from libcpp.unordered_map cimport unordered_map
from libc.stdint cimport int32_t, int64_t
from .typedefs cimport flags_t, attr_t, hash_t
@@ -91,3 +93,22 @@ cdef struct AliasC:
# Prior probability P(entity|alias) - should sum up to (at most) 1.
vector[float] probs
+
+
+cdef struct EdgeC:
+ hash_t label
+ int32_t head
+ int32_t tail
+
+
+cdef struct GraphC:
+ vector[vector[int32_t]] nodes
+ vector[EdgeC] edges
+ vector[float] weights
+ vector[int] n_heads
+ vector[int] n_tails
+ vector[int] first_head
+ vector[int] first_tail
+ unordered_set[int]* roots
+ unordered_map[hash_t, int]* node_map
+ unordered_map[hash_t, int]* edge_map
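For orientation: `GraphC` keeps every edge in one flat vector and indexes it per node through `first_head`/`n_heads` (and the tail equivalents), which is the invariant the lookup helpers in `graph.pyx` below rely on. A rough Python model of that scheme, purely illustrative and not part of the patch:

```python
# Illustrative model of the GraphC edge layout. All edges live in one flat
# list; per node we store how many edges point at it (n_heads) and the index
# of the first such edge (first_head), then scan forward from there. This
# mirrors get_head_nodes() in graph.pyx.
from dataclasses import dataclass, field
from typing import List, Tuple


@dataclass
class PyGraph:
    edges: List[Tuple[int, int]] = field(default_factory=list)  # (head, tail)
    n_heads: List[int] = field(default_factory=list)            # per node
    first_head: List[int] = field(default_factory=list)         # per node

    def heads_of(self, node: int) -> List[int]:
        todo = self.n_heads[node]
        found = []
        for head, tail in self.edges[self.first_head[node]:]:
            if todo <= 0:
                break
            if tail == node:
                found.append(head)
                todo -= 1
        return found
```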
diff --git a/spacy/tests/doc/test_doc_api.py b/spacy/tests/doc/test_doc_api.py
index fcadca061..74b8d825e 100644
--- a/spacy/tests/doc/test_doc_api.py
+++ b/spacy/tests/doc/test_doc_api.py
@@ -631,3 +631,24 @@ def test_doc_set_ents_invalid_spans(en_tokenizer):
retokenizer.merge(span)
with pytest.raises(IndexError):
doc.ents = spans
+
+
+def test_span_groups(en_tokenizer):
+ doc = en_tokenizer("Some text about Colombia and the Czech Republic")
+ doc.spans["hi"] = [Span(doc, 3, 4, label="bye")]
+ assert "hi" in doc.spans
+ assert "bye" not in doc.spans
+ assert len(doc.spans["hi"]) == 1
+ assert doc.spans["hi"][0].label_ == "bye"
+ doc.spans["hi"].append(doc[0:3])
+ assert len(doc.spans["hi"]) == 2
+ assert doc.spans["hi"][1].text == "Some text about"
+ assert [span.text for span in doc.spans["hi"]] == ["Colombia", "Some text about"]
+ assert not doc.spans["hi"].has_overlap
+ doc.ents = [Span(doc, 3, 4, label="GPE"), Span(doc, 6, 8, label="GPE")]
+ doc.spans["hi"].extend(doc.ents)
+ assert len(doc.spans["hi"]) == 4
+ assert [span.label_ for span in doc.spans["hi"]] == ["bye", "", "GPE", "GPE"]
+ assert doc.spans["hi"].has_overlap
+ del doc.spans["hi"]
+ assert "hi" not in doc.spans
diff --git a/spacy/tests/doc/test_graph.py b/spacy/tests/doc/test_graph.py
new file mode 100644
index 000000000..d5e2c05d1
--- /dev/null
+++ b/spacy/tests/doc/test_graph.py
@@ -0,0 +1,57 @@
+from spacy.vocab import Vocab
+from spacy.tokens.doc import Doc
+from spacy.tokens.graph import Graph
+
+
+def test_graph_init():
+ doc = Doc(Vocab(), words=["a", "b", "c", "d"])
+ graph = Graph(doc, name="hello")
+ assert graph.name == "hello"
+ assert graph.doc is doc
+
+
+def test_graph_edges_and_nodes():
+ doc = Doc(Vocab(), words=["a", "b", "c", "d"])
+ graph = Graph(doc, name="hello")
+ node1 = graph.add_node((0,))
+ assert graph.get_node((0,)) == node1
+ node2 = graph.add_node((1, 3))
+ assert list(node2) == [1, 3]
+ graph.add_edge(
+ node1,
+ node2,
+ label="one",
+ weight=-10.5
+ )
+ assert graph.has_edge(
+ node1,
+ node2,
+ label="one"
+ )
+ assert node1.heads() == []
+ assert [tuple(h) for h in node2.heads()] == [(0,)]
+ assert [tuple(t) for t in node1.tails()] == [(1, 3)]
+ assert [tuple(t) for t in node2.tails()] == []
+
+
+def test_graph_walk():
+ doc = Doc(Vocab(), words=["a", "b", "c", "d"])
+ graph = Graph(
+ doc,
+ name="hello",
+ nodes=[(0,), (1,), (2,), (3,)],
+ edges=[(0, 1), (0, 2), (0, 3), (3, 0)],
+ labels=None,
+ weights=None
+ )
+ node0, node1, node2, node3 = list(graph.nodes)
+ assert [tuple(h) for h in node0.heads()] == [(3,)]
+ assert [tuple(h) for h in node1.heads()] == [(0,)]
+ assert [tuple(h) for h in node0.walk_heads()] == [(3,), (0,)]
+ assert [tuple(h) for h in node1.walk_heads()] == [(0,), (3,), (0,)]
+ assert [tuple(h) for h in node2.walk_heads()] == [(0,), (3,), (0,)]
+ assert [tuple(h) for h in node3.walk_heads()] == [(0,), (3,)]
+ assert [tuple(t) for t in node0.walk_tails()] == [(1,), (2,), (3,), (0,)]
+ assert [tuple(t) for t in node1.walk_tails()] == []
+ assert [tuple(t) for t in node2.walk_tails()] == []
+ assert [tuple(t) for t in node3.walk_tails()] == [(0,), (1,), (2,), (3,)]
diff --git a/spacy/tests/serialize/test_serialize_doc.py b/spacy/tests/serialize/test_serialize_doc.py
index 00b9d12d4..837c128af 100644
--- a/spacy/tests/serialize/test_serialize_doc.py
+++ b/spacy/tests/serialize/test_serialize_doc.py
@@ -56,6 +56,13 @@ def test_serialize_doc_exclude(en_vocab):
assert not new_doc.user_data
+def test_serialize_doc_span_groups(en_vocab):
+ doc = Doc(en_vocab, words=["hello", "world", "!"])
+ doc.spans["content"] = [doc[0:2]]
+ new_doc = Doc(en_vocab).from_bytes(doc.to_bytes())
+ assert len(new_doc.spans["content"]) == 1
+
+
def test_serialize_doc_bin():
doc_bin = DocBin(attrs=["LEMMA", "ENT_IOB", "ENT_TYPE"], store_user_data=True)
texts = ["Some text", "Lots of texts...", "..."]
@@ -63,6 +70,7 @@ def test_serialize_doc_bin():
nlp = English()
for doc in nlp.pipe(texts):
doc.cats = cats
+ doc.spans["start"] = [doc[0:2]]
doc_bin.add(doc)
bytes_data = doc_bin.to_bytes()
@@ -73,6 +81,7 @@ def test_serialize_doc_bin():
for i, doc in enumerate(reloaded_docs):
assert doc.text == texts[i]
assert doc.cats == cats
+ assert len(doc.spans) == 1
def test_serialize_doc_bin_unknown_spaces(en_vocab):
diff --git a/spacy/tokens/_dict_proxies.py b/spacy/tokens/_dict_proxies.py
new file mode 100644
index 000000000..b10f6d484
--- /dev/null
+++ b/spacy/tokens/_dict_proxies.py
@@ -0,0 +1,49 @@
+from typing import Iterable, Tuple, Union, TYPE_CHECKING
+import weakref
+from collections import UserDict
+import srsly
+
+from .span_group import SpanGroup
+
+if TYPE_CHECKING:
+ # This lets us add type hints for mypy etc. without causing circular imports
+ from .doc import Doc # noqa: F401
+ from .span import Span # noqa: F401
+
+
+# Why inherit from UserDict instead of dict here?
+# Well, the 'dict' class doesn't necessarily delegate everything nicely,
+# for performance reasons. The UserDict is slower but better behaved.
+# See https://treyhunner.com/2019/04/why-you-shouldnt-inherit-from-list-and-dict-in-python/
+class SpanGroups(UserDict):
+ """A dict-like proxy held by the Doc, to control access to span groups."""
+
+ def __init__(
+ self, doc: "Doc", items: Iterable[Tuple[str, SpanGroup]] = tuple()
+ ) -> None:
+ self.doc_ref = weakref.ref(doc)
+ UserDict.__init__(self, items)
+
+ def __setitem__(self, key: str, value: Union[SpanGroup, Iterable["Span"]]) -> None:
+ if not isinstance(value, SpanGroup):
+ value = self._make_span_group(key, value)
+ assert value.doc is self.doc_ref()
+ UserDict.__setitem__(self, key, value)
+
+ def _make_span_group(self, name: str, spans: Iterable["Span"]) -> SpanGroup:
+ return SpanGroup(self.doc_ref(), name=name, spans=spans)
+
+ def to_bytes(self) -> bytes:
+ # We don't need to serialize this as a dict, because the groups
+ # know their names.
+ msg = [value.to_bytes() for value in self.values()]
+ return srsly.msgpack_dumps(msg)
+
+ def from_bytes(self, bytes_data: bytes) -> "SpanGroups":
+ msg = srsly.msgpack_loads(bytes_data)
+ self.clear()
+ doc = self.doc_ref()
+ for value_bytes in msg:
+ group = SpanGroup(doc).from_bytes(value_bytes)
+ self[group.name] = group
+ return self
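A short sketch of how the proxy behaves from user code (assuming a blank pipeline, and that `SpanGroup` is re-exported from `spacy.tokens` as the docs in this change set assume):

```python
import spacy
from spacy.tokens import SpanGroup

nlp = spacy.blank("en")
doc = nlp("Their goi ng home")

# Assigning a plain list of spans goes through SpanGroups.__setitem__,
# which wraps the list in a SpanGroup named after the key.
doc.spans["errors"] = [doc[0:1], doc[2:4]]
assert isinstance(doc.spans["errors"], SpanGroup)
assert doc.spans["errors"].name == "errors"

# The proxy also round-trips through its msgpack representation.
data = doc.spans.to_bytes()
doc.spans.clear()
doc.spans.from_bytes(data)
assert len(doc.spans["errors"]) == 2
```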
diff --git a/spacy/tokens/_serialize.py b/spacy/tokens/_serialize.py
index 821f55eb6..bb1f515ec 100644
--- a/spacy/tokens/_serialize.py
+++ b/spacy/tokens/_serialize.py
@@ -33,6 +33,7 @@ class DocBin:
{
"attrs": List[uint64], # e.g. [TAG, HEAD, ENT_IOB, ENT_TYPE]
"tokens": bytes, # Serialized numpy uint64 array with the token data
+ "spans": List[Dict[str, bytes]], # SpanGroups data for each doc
"spaces": bytes, # Serialized numpy boolean array with spaces data
"lengths": bytes, # Serialized numpy int32 array with the doc lengths
"strings": List[unicode] # List of unique strings in the token data
@@ -70,6 +71,7 @@ class DocBin:
self.tokens = []
self.spaces = []
self.cats = []
+ self.span_groups = []
self.user_data = []
self.flags = []
self.strings = set()
@@ -107,6 +109,10 @@ class DocBin:
self.strings.add(token.ent_kb_id_)
self.cats.append(doc.cats)
self.user_data.append(srsly.msgpack_dumps(doc.user_data))
+ self.span_groups.append(doc.spans.to_bytes())
+ for key, group in doc.spans.items():
+ for span in group:
+ self.strings.add(span.label_)
def get_docs(self, vocab: Vocab) -> Iterator[Doc]:
"""Recover Doc objects from the annotations, using the given vocab.
@@ -130,6 +136,10 @@ class DocBin:
doc = Doc(vocab, words=tokens[:, orth_col], spaces=spaces)
doc = doc.from_array(self.attrs, tokens)
doc.cats = self.cats[i]
+ if self.span_groups[i]:
+ doc.spans.from_bytes(self.span_groups[i])
+ else:
+ doc.spans.clear()
if i < len(self.user_data) and self.user_data[i] is not None:
user_data = srsly.msgpack_loads(self.user_data[i], use_list=False)
doc.user_data.update(user_data)
@@ -161,6 +171,7 @@ class DocBin:
self.spaces.extend(other.spaces)
self.strings.update(other.strings)
self.cats.extend(other.cats)
+ self.span_groups.extend(other.span_groups)
self.flags.extend(other.flags)
self.user_data.extend(other.user_data)
@@ -185,6 +196,7 @@ class DocBin:
"strings": list(sorted(self.strings)),
"cats": self.cats,
"flags": self.flags,
+ "span_groups": self.span_groups,
}
if self.store_user_data:
msg["user_data"] = self.user_data
@@ -213,6 +225,7 @@ class DocBin:
self.tokens = NumpyOps().unflatten(flat_tokens, lengths)
self.spaces = NumpyOps().unflatten(flat_spaces, lengths)
self.cats = msg["cats"]
+ self.span_groups = msg.get("span_groups", [b"" for _ in lengths])
self.flags = msg.get("flags", [{} for _ in lengths])
if "user_data" in msg:
self.user_data = list(msg["user_data"])
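A sketch of the intended `DocBin` round-trip with span groups, assuming a blank pipeline:

```python
import spacy
from spacy.tokens import DocBin

nlp = spacy.blank("en")
doc = nlp("Some text about Colombia")
doc.spans["regions"] = [doc[3:4]]

doc_bin = DocBin()
doc_bin.add(doc)
reloaded = list(DocBin().from_bytes(doc_bin.to_bytes()).get_docs(nlp.vocab))
assert reloaded[0].spans["regions"][0].text == "Colombia"
```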
diff --git a/spacy/tokens/doc.pxd b/spacy/tokens/doc.pxd
index 08f795b1a..c74ee0b63 100644
--- a/spacy/tokens/doc.pxd
+++ b/spacy/tokens/doc.pxd
@@ -2,7 +2,7 @@ from cymem.cymem cimport Pool
cimport numpy as np
from ..vocab cimport Vocab
-from ..structs cimport TokenC, LexemeC
+from ..structs cimport TokenC, LexemeC, SpanC
from ..typedefs cimport attr_t
from ..attrs cimport attr_id_t
@@ -33,6 +33,7 @@ cdef int token_by_end(const TokenC* tokens, int length, int end_char) except -2
cdef int [:,:] _get_lca_matrix(Doc, int start, int end)
+
cdef class Doc:
cdef readonly Pool mem
cdef readonly Vocab vocab
@@ -43,6 +44,7 @@ cdef class Doc:
cdef public object tensor
cdef public object cats
cdef public object user_data
+ cdef readonly object spans
cdef TokenC* c
diff --git a/spacy/tokens/doc.pyx b/spacy/tokens/doc.pyx
index 9eedf214b..1f5842948 100644
--- a/spacy/tokens/doc.pyx
+++ b/spacy/tokens/doc.pyx
@@ -16,6 +16,7 @@ from thinc.util import copy_array
import warnings
from .span cimport Span
+from ._dict_proxies import SpanGroups
from .token cimport Token
from ..lexeme cimport Lexeme, EMPTY_LEXEME
from ..typedefs cimport attr_t, flags_t
@@ -222,6 +223,7 @@ cdef class Doc:
self.vocab = vocab
size = max(20, (len(words) if words is not None else 0))
self.mem = Pool()
+ self.spans = SpanGroups(self)
# Guarantee self.lex[i-x], for any i >= 0 and x < padding is in bounds
# However, we need to remember the true starting places, so that we can
# realloc.
@@ -1255,6 +1257,9 @@ cdef class Doc:
strings.add(token.ent_kb_id_)
strings.add(token.ent_id_)
strings.add(token.norm_)
+ for group in self.spans.values():
+ for span in group:
+ strings.add(span.label_)
# Msgpack doesn't distinguish between lists and tuples, which is
# vexing for user data. As a best guess, we *know* that within
# keys, we must have tuples. In values we just have to hope
@@ -1266,6 +1271,7 @@ cdef class Doc:
"sentiment": lambda: self.sentiment,
"tensor": lambda: self.tensor,
"cats": lambda: self.cats,
+ "spans": lambda: self.spans.to_bytes(),
"strings": lambda: list(strings),
"has_unknown_spaces": lambda: self.has_unknown_spaces
}
@@ -1290,18 +1296,19 @@
        """
        if self.length != 0:
            raise ValueError(Errors.E033.format(length=self.length))
        deserializers = {
            "text": lambda b: None,
            "array_head": lambda b: None,
            "array_body": lambda b: None,
            "sentiment": lambda b: None,
            "tensor": lambda b: None,
            "cats": lambda b: None,
+            "spans": lambda b: None,
            "strings": lambda b: None,
            "user_data_keys": lambda b: None,
            "user_data_values": lambda b: None,
            "has_unknown_spaces": lambda b: None
        }
# Msgpack doesn't distinguish between lists and tuples, which is
# vexing for user data. As a best guess, we *know* that within
# keys, we must have tuples. In values we just have to hope
@@ -1336,9 +1343,12 @@
self.push_back(lex, has_space)
start = end + has_space
self.from_array(msg["array_head"][2:], attrs[:, 2:])
+ if "spans" in msg:
+ self.spans.from_bytes(msg["spans"])
+ else:
+ self.spans.clear()
return self
-
def extend_tensor(self, tensor):
"""Concatenate a new tensor onto the doc.tensor object.
diff --git a/spacy/tokens/graph.pxd b/spacy/tokens/graph.pxd
new file mode 100644
index 000000000..6f2f80656
--- /dev/null
+++ b/spacy/tokens/graph.pxd
@@ -0,0 +1,13 @@
+from libcpp.vector cimport vector
+from cymem.cymem cimport Pool
+from preshed.maps cimport PreshMap
+from ..structs cimport GraphC, EdgeC
+
+
+cdef class Graph:
+ cdef GraphC c
+ cdef Pool mem
+ cdef PreshMap node_map
+ cdef PreshMap edge_map
+ cdef object doc_ref
+ cdef public str name
diff --git a/spacy/tokens/graph.pyx b/spacy/tokens/graph.pyx
new file mode 100644
index 000000000..9351435f8
--- /dev/null
+++ b/spacy/tokens/graph.pyx
@@ -0,0 +1,716 @@
+# cython: infer_types=True, cdivision=True, boundscheck=False, binding=True
+from typing import List, Tuple, Generator
+from libc.stdint cimport int32_t, int64_t
+from libcpp.pair cimport pair
+from libcpp.unordered_map cimport unordered_map
+from libcpp.unordered_set cimport unordered_set
+from cython.operator cimport dereference
+cimport cython
+import weakref
+from preshed.maps cimport map_get_unless_missing
+from murmurhash.mrmr cimport hash64
+from ..typedefs cimport hash_t
+from ..strings import get_string_id
+from ..structs cimport EdgeC, GraphC
+from .token import Token
+
+
+@cython.freelist(8)
+cdef class Edge:
+ cdef readonly Graph graph
+ cdef readonly int i
+
+ def __init__(self, Graph graph, int i):
+ self.graph = graph
+ self.i = i
+
+ @property
+ def is_none(self) -> bool:
+ return False
+
+ @property
+ def doc(self) -> "Doc":
+ return self.graph.doc
+
+ @property
+ def head(self) -> "Node":
+ return Node(self.graph, self.graph.c.edges[self.i].head)
+
+ @property
+ def tail(self) -> "Node":
+ return Node(self.graph, self.graph.c.edges[self.i].tail)
+
+ @property
+ def label(self) -> int:
+ return self.graph.c.edges[self.i].label
+
+ @property
+ def weight(self) -> float:
+ return self.graph.c.weights[self.i]
+
+ @property
+ def label_(self) -> str:
+ return self.doc.vocab.strings[self.label]
+
+
+@cython.freelist(8)
+cdef class Node:
+ cdef readonly Graph graph
+ cdef readonly int i
+
+ def __init__(self, Graph graph, int i):
+ """A reference to a node of an annotation graph. Each node is made up of
+ an ordered set of zero or more token indices.
+
+ Node references are usually created by the Graph object itself, or from
+ the Node or Edge objects. You usually won't need to instantiate this
+ class yourself.
+ """
+ cdef int length = graph.c.nodes.size()
+ if i >= length or -i >= length:
+ raise IndexError(f"Node index {i} out of bounds ({length})")
+ if i < 0:
+ i += length
+ self.graph = graph
+ self.i = i
+
+ def __eq__(self, other):
+ if self.graph is not other.graph:
+ return False
+ else:
+ return self.i == other.i
+
+ def __iter__(self) -> Generator[int]:
+ for i in self.graph.c.nodes[self.i]:
+ yield i
+
+ def __getitem__(self, int i) -> int:
+ """Get a token index from the node's set of tokens."""
+ length = self.graph.c.nodes[self.i].size()
+ if i >= length or -i >= length:
+ raise IndexError(f"Token index {i} out of bounds ({length})")
+ if i < 0:
+ i += length
+ return self.graph.c.nodes[self.i][i]
+
+ def __len__(self) -> int:
+ """The number of tokens that make up the node."""
+ return self.graph.c.nodes[self.i].size()
+
+ @property
+ def is_none(self) -> bool:
+ """Whether the node is a special value, indicating 'none'.
+
+ The NoneNode type is returned by the Graph, Edge and Node objects when
+ there is no match to a query. It has the same API as Node, but it always
+ returns NoneNode, NoneEdge or empty lists for its queries.
+ """
+ return False
+
+ @property
+ def doc(self) -> "Doc":
+ """The Doc object that the graph refers to."""
+ return self.graph.doc
+
+ @property
+ def tokens(self) -> Tuple[Token]:
+ """A tuple of Token objects that make up the node."""
+ doc = self.doc
+ return tuple([doc[i] for i in self])
+
+ def head(self, i=None, label=None) -> "Node":
+ """Get the head of the first matching edge, searching by index, label,
+ both or neither.
+
+ For instance, `node.head(i=1)` will get the head of the second edge that
+ this node is a tail of. `node.head(i=1, label="ARG0")` will further
+ check that the second edge has the label `"ARG0"`.
+
+ If no matching node can be found, the graph's NoneNode is returned.
+ """
+        return self.headed(i=i, label=label).head
+
+ def tail(self, i=None, label=None) -> "Node":
+ """Get the tail of the first matching edge, searching by index, label,
+ both or neither.
+
+ If no matching node can be found, the graph's NoneNode is returned.
+ """
+ return self.tailed(i=i, label=label).tail
+
+    def sibling(self, i=None, label=None):
+        """Get the first matching sibling node. Two nodes are siblings if they
+        are both tails of the same head.
+        If no matching node can be found, the graph's NoneNode is returned.
+        """
+        if label is not None:
+            label = get_string_id(label)
+        if i is None:
+ siblings = self.siblings(label=label)
+ return siblings[0] if siblings else NoneNode(self)
+ else:
+ edges = []
+            for h in self.heads():
+                edges.extend([e for e in h.taileds() if e.tail.i != self.i])
+ if i >= len(edges):
+ return NoneNode(self)
+ elif label is not None and edges[i].label != label:
+ return NoneNode(self)
+ else:
+ return edges[i].tail
+
+ def heads(self, label=None) -> List["Node"]:
+ """Find all matching heads of this node."""
+ cdef vector[int] edge_indices
+ self._find_edges(edge_indices, "head", label)
+ return [Node(self.graph, self.graph.c.edges[i].head) for i in edge_indices]
+
+ def tails(self, label=None) -> List["Node"]:
+ """Find all matching tails of this node."""
+ cdef vector[int] edge_indices
+ self._find_edges(edge_indices, "tail", label)
+ return [Node(self.graph, self.graph.c.edges[i].tail) for i in edge_indices]
+
+    def siblings(self, label=None) -> List["Node"]:
+        """Find all matching siblings of this node. Two nodes are siblings if
+        they are tails of the same head.
+        """
+        if label is not None:
+            label = get_string_id(label)
+        edges = []
+        for h in self.heads():
+            edges.extend([e for e in h.taileds() if e.tail.i != self.i])
+ if label is None:
+ return [e.tail for e in edges]
+ else:
+ return [e.tail for e in edges if e.label == label]
+
+ def headed(self, i=None, label=None) -> Edge:
+ """Find the first matching edge headed by this node.
+ If no matching edge can be found, the graph's NoneEdge is returned.
+ """
+        start, end = self._get_range(i, self.graph.c.n_heads[self.i])
+ idx = self._find_edge("head", start, end, label)
+ if idx == -1:
+ return NoneEdge(self.graph)
+ else:
+ return Edge(self.graph, idx)
+
+ def tailed(self, i=None, label=None) -> Edge:
+ """Find the first matching edge tailed by this node.
+ If no matching edge can be found, the graph's NoneEdge is returned.
+ """
+        start, end = self._get_range(i, self.graph.c.n_tails[self.i])
+ idx = self._find_edge("tail", start, end, label)
+ if idx == -1:
+ return NoneEdge(self.graph)
+ else:
+ return Edge(self.graph, idx)
+
+ def headeds(self, label=None) -> List[Edge]:
+ """Find all matching edges headed by this node."""
+ cdef vector[int] edge_indices
+ self._find_edges(edge_indices, "head", label)
+ return [Edge(self.graph, i) for i in edge_indices]
+
+ def taileds(self, label=None) -> List["Edge"]:
+ """Find all matching edges headed by this node."""
+ cdef vector[int] edge_indices
+ self._find_edges(edge_indices, "tail", label)
+ return [Edge(self.graph, i) for i in edge_indices]
+
+ def walk_heads(self):
+ cdef vector[int] node_indices
+ walk_head_nodes(node_indices, &self.graph.c, self.i)
+ for i in node_indices:
+ yield Node(self.graph, i)
+
+ def walk_tails(self):
+ cdef vector[int] node_indices
+ walk_tail_nodes(node_indices, &self.graph.c, self.i)
+ for i in node_indices:
+ yield Node(self.graph, i)
+
+ cdef (int, int) _get_range(self, i, n):
+ if i is None:
+ return (0, n)
+ elif i < n:
+ return (i, i+1)
+ else:
+ return (0, 0)
+
+    cdef int _find_edge(self, str direction, int start, int end, label) except -2:
+        if label is not None:
+            label = get_string_id(label)
+        if direction == "head":
+            get_edges = get_head_edges
+        else:
+            get_edges = get_tail_edges
+        cdef vector[int] edge_indices
+        get_edges(edge_indices, &self.graph.c, self.i)
+        if label is None:
+            if start < end:
+                return edge_indices[start]
+            else:
+                return -1
+ for edge_index in edge_indices[start:end]:
+ if self.graph.c.edges[edge_index].label == label:
+ return edge_index
+ else:
+ return -1
+
+    cdef int _find_edges(self, vector[int]& edge_indices, str direction, label):
+        if label is not None:
+            label = get_string_id(label)
+ if direction == "head":
+ get_edges = get_head_edges
+ else:
+ get_edges = get_tail_edges
+ if label is None:
+ get_edges(edge_indices, &self.graph.c, self.i)
+ return edge_indices.size()
+ cdef vector[int] unfiltered
+ get_edges(unfiltered, &self.graph.c, self.i)
+ for edge_index in unfiltered:
+ if self.graph.c.edges[edge_index].label == label:
+ edge_indices.push_back(edge_index)
+ return edge_indices.size()
+
+
+cdef class NoneEdge(Edge):
+ """An Edge subclass, representing a non-result. The NoneEdge has the same
+ API as other Edge instances, but always returns NoneEdge, NoneNode, or empty
+ lists.
+ """
+ def __init__(self, graph):
+ self.graph = graph
+ self.i = -1
+
+ @property
+ def doc(self) -> "Doc":
+ return self.graph.doc
+
+ @property
+ def head(self) -> "NoneNode":
+ return NoneNode(self.graph)
+
+ @property
+ def tail(self) -> "NoneNode":
+ return NoneNode(self.graph)
+
+ @property
+ def label(self) -> int:
+ return 0
+
+ @property
+ def weight(self) -> float:
+ return 0.0
+
+ @property
+ def label_(self) -> str:
+ return ""
+
+
+cdef class NoneNode(Node):
+ def __init__(self, graph):
+ self.graph = graph
+ self.i = -1
+
+ def __getitem__(self, int i):
+ raise IndexError("Cannot index into NoneNode.")
+
+ def __len__(self):
+ return 0
+
+ @property
+ def is_none(self):
+        return True
+
+ @property
+ def doc(self):
+ return self.graph.doc
+
+ @property
+ def tokens(self):
+ return tuple()
+
+ def head(self, i=None, label=None):
+ return self
+
+ def tail(self, i=None, label=None):
+ return self
+
+ def walk_heads(self):
+ yield from []
+
+ def walk_tails(self):
+ yield from []
+
+
+cdef class Graph:
+ """A set of directed labelled relationships between sets of tokens.
+
+ EXAMPLE:
+ Construction 1
+ >>> graph = Graph(doc, name="srl")
+
+ Construction 2
+ >>> graph = Graph(
+ doc,
+ name="srl",
+            nodes=[(0,), (1, 3), ()],
+ edges=[(0, 2), (2, 1)]
+ )
+
+ Construction 3
+ >>> graph = Graph(
+ doc,
+ name="srl",
+            nodes=[(0,), (1, 3), ()],
+ edges=[(2, 0), (0, 1)],
+ labels=["word sense ID 1675", "agent"],
+ weights=[-42.6, -1.7]
+ )
+ >>> assert graph.has_node((0,))
+ >>> assert graph.has_edge((0,), (1,3), label="agent")
+ """
+ def __init__(self, doc, *, name="", nodes=[], edges=[], labels=None, weights=None):
+ """Create a Graph object.
+
+ doc (Doc): The Doc object the graph will refer to.
+ name (str): A string name to help identify the graph. Defaults to "".
+ nodes (List[Tuple[int]]): A list of token-index tuples to add to the graph
+ as nodes. Defaults to [].
+ edges (List[Tuple[int, int]]): A list of edges between the provided nodes.
+ Each edge should be a (head, tail) tuple, where `head` and `tail`
+ are integers pointing into the `nodes` list. Defaults to [].
+ labels (Optional[List[str]]): A list of labels for the provided edges.
+ If None, all of the edges specified by the edges argument will have
+ be labelled with the empty string (""). If `labels` is not `None`,
+ it must have the same length as the `edges` argument.
+ weights (Optional[List[float]]): A list of weights for the provided edges.
+ If None, all of the edges specified by the edges argument will
+ have the weight 0.0. If `weights` is not `None`, it must have the
+ same length as the `edges` argument.
+ """
+ if weights is not None:
+ assert len(weights) == len(edges)
+ else:
+ weights = [0.0] * len(edges)
+ if labels is not None:
+ assert len(labels) == len(edges)
+ else:
+ labels = [""] * len(edges)
+ self.c.node_map = new unordered_map[hash_t, int]()
+ self.c.edge_map = new unordered_map[hash_t, int]()
+ self.c.roots = new unordered_set[int]()
+ self.name = name
+ self.doc_ref = weakref.ref(doc)
+ for node in nodes:
+ self.add_node(node)
+ for (head, tail), label, weight in zip(edges, labels, weights):
+ self.add_edge(
+ Node(self, head),
+ Node(self, tail),
+ label=label,
+ weight=weight
+ )
+
+ def __dealloc__(self):
+ del self.c.node_map
+ del self.c.edge_map
+ del self.c.roots
+
+ @property
+ def doc(self) -> "Doc":
+ """The Doc object the graph refers to."""
+ return self.doc_ref()
+
+ @property
+ def edges(self) -> Generator[Edge]:
+ """Iterate over the edges in the graph."""
+ for i in range(self.c.edges.size()):
+ yield Edge(self, i)
+
+ @property
+ def nodes(self) -> Generator[Node]:
+ """Iterate over the nodes in the graph."""
+ for i in range(self.c.nodes.size()):
+ yield Node(self, i)
+
+ def add_edge(self, head, tail, *, label="", weight=None) -> Edge:
+ """Add an edge to the graph, connecting two groups of tokens.
+
+ If there is already an edge for the (head, tail, label) triple, it will
+ be returned, and no new edge will be created. The weight of the edge
+ will be updated if a weight is specified.
+ """
+        label_hash = self.doc.vocab.strings.as_int(label)
+        weight_float = weight if weight is not None else 0.0
+        edge_index = add_edge(
+            &self.c,
+            EdgeC(
+                head=self.add_node(head).i,
+                tail=self.add_node(tail).i,
+                label=label_hash,
+            ),
+            weight=weight_float
+        )
+ return Edge(self, edge_index)
+
+ def get_edge(self, head, tail, *, label="") -> Edge:
+ """Look up an edge in the graph. If the graph has no matching edge,
+ the NoneEdge object is returned.
+ """
+ head_node = self.get_node(head)
+ if head_node.is_none:
+ return NoneEdge(self)
+ tail_node = self.get_node(tail)
+ if tail_node.is_none:
+ return NoneEdge(self)
+ edge_index = get_edge(
+ &self.c,
+ EdgeC(head=head_node.i, tail=tail_node.i, label=get_string_id(label))
+ )
+ if edge_index < 0:
+ return NoneEdge(self)
+ else:
+ return Edge(self, edge_index)
+
+ def has_edge(self, head, tail, label) -> bool:
+ """Check whether a (head, tail, label) triple is an edge in the graph."""
+ return not self.get_edge(head, tail, label=label).is_none
+
+ def add_node(self, indices) -> Node:
+ """Add a node to the graph and return it. Nodes refer to ordered sets
+ of token indices.
+
+ This method is idempotent: if there is already a node for the given
+ indices, it is returned without a new node being created.
+ """
+ if isinstance(indices, Node):
+ return indices
+ cdef vector[int32_t] node
+ node.reserve(len(indices))
+ for idx in indices:
+ node.push_back(idx)
+ i = add_node(&self.c, node)
+ print("Add node", indices, i)
+ return Node(self, i)
+
+ def get_node(self, indices) -> Node:
+ """Get a node from the graph, or the NoneNode if there is no node for
+ the given indices.
+ """
+ if isinstance(indices, Node):
+ return indices
+ cdef vector[int32_t] node
+ node.reserve(len(indices))
+ for idx in indices:
+ node.push_back(idx)
+ node_index = get_node(&self.c, node)
+ if node_index < 0:
+ return NoneNode(self)
+ else:
+ print("Get node", indices, node_index)
+ return Node(self, node_index)
+
+ def has_node(self, tuple indices) -> bool:
+ """Check whether the graph has a node for the given indices."""
+ return not self.get_node(indices).is_none
+
+
+cdef int add_edge(GraphC* graph, EdgeC edge, float weight) nogil:
+ key = hash64(&edge, sizeof(edge), 0)
+ it = graph.edge_map.find(key)
+ if it != graph.edge_map.end():
+ edge_index = dereference(it).second
+ graph.weights[edge_index] = weight
+ return edge_index
+ else:
+ edge_index = graph.edges.size()
+ graph.edge_map.insert(pair[hash_t, int](key, edge_index))
+ graph.edges.push_back(edge)
+ if graph.n_tails[edge.head] == 0:
+ graph.first_tail[edge.head] = edge_index
+ if graph.n_heads[edge.tail] == 0:
+ graph.first_head[edge.tail] = edge_index
+ graph.n_tails[edge.head] += 1
+ graph.n_heads[edge.tail] += 1
+ graph.weights.push_back(weight)
+ # If we had the tail marked as a root, remove it.
+ tail_root_index = graph.roots.find(edge.tail)
+ if tail_root_index != graph.roots.end():
+ graph.roots.erase(tail_root_index)
+ return edge_index
+
+
+cdef int get_edge(const GraphC* graph, EdgeC edge) nogil:
+ key = hash64(&edge, sizeof(edge), 0)
+ it = graph.edge_map.find(key)
+ if it == graph.edge_map.end():
+ return -1
+ else:
+ return dereference(it).second
+
+
+cdef int has_edge(const GraphC* graph, EdgeC edge) nogil:
+ return get_edge(graph, edge) >= 0
+
+
+cdef int add_node(GraphC* graph, vector[int32_t]& node) nogil:
+ key = hash64(&node[0], node.size() * sizeof(node[0]), 0)
+ it = graph.node_map.find(key)
+ if it != graph.node_map.end():
+ # Item found. Convert the iterator to an index value.
+ return dereference(it).second
+ else:
+ index = graph.nodes.size()
+ graph.nodes.push_back(node)
+ graph.n_heads.push_back(0)
+ graph.n_tails.push_back(0)
+ graph.first_head.push_back(0)
+ graph.first_tail.push_back(0)
+ graph.roots.insert(index)
+ graph.node_map.insert(pair[hash_t, int](key, index))
+ return index
+
+
+cdef int get_node(const GraphC* graph, vector[int32_t] node) nogil:
+ key = hash64(&node[0], node.size() * sizeof(node[0]), 0)
+ it = graph.node_map.find(key)
+ if it == graph.node_map.end():
+ return -1
+ else:
+ return dereference(it).second
+
+
+cdef int has_node(const GraphC* graph, vector[int32_t] node) nogil:
+ return get_node(graph, node) >= 0
+
+
+cdef int get_head_nodes(vector[int]& output, const GraphC* graph, int node) nogil:
+ todo = graph.n_heads[node]
+ if todo == 0:
+ return 0
+ output.reserve(output.size() + todo)
+ start = graph.first_head[node]
+ end = graph.edges.size()
+ for i in range(start, end):
+ if todo <= 0:
+ break
+ elif graph.edges[i].tail == node:
+ output.push_back(graph.edges[i].head)
+ todo -= 1
+ return todo
+
+
+cdef int get_tail_nodes(vector[int]& output, const GraphC* graph, int node) nogil:
+ todo = graph.n_tails[node]
+ if todo == 0:
+ return 0
+ output.reserve(output.size() + todo)
+ start = graph.first_tail[node]
+ end = graph.edges.size()
+ for i in range(start, end):
+ if todo <= 0:
+ break
+ elif graph.edges[i].head == node:
+ output.push_back(graph.edges[i].tail)
+ todo -= 1
+ return todo
+
+
+cdef int get_sibling_nodes(vector[int]& output, const GraphC* graph, int node) nogil:
+ cdef vector[int] heads
+ cdef vector[int] tails
+ get_head_nodes(heads, graph, node)
+ for i in range(heads.size()):
+ get_tail_nodes(tails, graph, heads[i])
+ for j in range(tails.size()):
+ if tails[j] != node:
+ output.push_back(tails[j])
+ tails.clear()
+ return output.size()
+
+
+cdef int get_head_edges(vector[int]& output, const GraphC* graph, int node) nogil:
+ todo = graph.n_heads[node]
+ if todo == 0:
+ return 0
+ output.reserve(output.size() + todo)
+ start = graph.first_head[node]
+ end = graph.edges.size()
+ for i in range(start, end):
+ if todo <= 0:
+ break
+ elif graph.edges[i].tail == node:
+ output.push_back(i)
+ todo -= 1
+ return todo
+
+
+cdef int get_tail_edges(vector[int]& output, const GraphC* graph, int node) nogil:
+ todo = graph.n_tails[node]
+ if todo == 0:
+ return 0
+ output.reserve(output.size() + todo)
+ start = graph.first_tail[node]
+ end = graph.edges.size()
+ for i in range(start, end):
+ if todo <= 0:
+ break
+ elif graph.edges[i].head == node:
+ output.push_back(i)
+ todo -= 1
+ return todo
+
+
+cdef int walk_head_nodes(vector[int]& output, const GraphC* graph, int node) nogil:
+ cdef unordered_set[int] seen = unordered_set[int]()
+ get_head_nodes(output, graph, node)
+ seen.insert(node)
+ i = 0
+ while i < output.size():
+ if seen.find(output[i]) == seen.end():
+ seen.insert(output[i])
+ get_head_nodes(output, graph, output[i])
+ i += 1
+ return i
+
+
+cdef int walk_tail_nodes(vector[int]& output, const GraphC* graph, int node) nogil:
+ cdef unordered_set[int] seen = unordered_set[int]()
+ get_tail_nodes(output, graph, node)
+ seen.insert(node)
+ i = 0
+ while i < output.size():
+ if seen.find(output[i]) == seen.end():
+ seen.insert(output[i])
+ get_tail_nodes(output, graph, output[i])
+ i += 1
+ return i
+
+
+cdef int walk_head_edges(vector[int]& output, const GraphC* graph, int node) nogil:
+ cdef unordered_set[int] seen = unordered_set[int]()
+ get_head_edges(output, graph, node)
+ seen.insert(node)
+ i = 0
+ while i < output.size():
+ if seen.find(output[i]) == seen.end():
+ seen.insert(output[i])
+ get_head_edges(output, graph, output[i])
+ i += 1
+ return i
+
+
+cdef int walk_tail_edges(vector[int]& output, const GraphC* graph, int node) nogil:
+ cdef unordered_set[int] seen = unordered_set[int]()
+ get_tail_edges(output, graph, node)
+ seen.insert(node)
+ i = 0
+ while i < output.size():
+ if seen.find(output[i]) == seen.end():
+ seen.insert(output[i])
+ get_tail_edges(output, graph, output[i])
+ i += 1
+ return i
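An end-to-end sketch of the API added here, following the docstring's SRL example; the graph name and labels are illustrative, and the snippet is not part of the patch:

```python
import spacy
from spacy.tokens.graph import Graph

nlp = spacy.blank("en")
doc = nlp("The dog chased the cat")

graph = Graph(doc, name="srl")
pred = graph.add_node((2,))    # "chased"
arg0 = graph.add_node((0, 1))  # "The dog"
graph.add_edge(pred, arg0, label="ARG0", weight=-1.7)

assert graph.has_edge((2,), (0, 1), label="ARG0")
assert [tuple(tail) for tail in pred.tails()] == [(0, 1)]
assert [tuple(head) for head in arg0.heads()] == [(2,)]
```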
diff --git a/spacy/tokens/span.pxd b/spacy/tokens/span.pxd
index cc6b908bb..78bee0a8c 100644
--- a/spacy/tokens/span.pxd
+++ b/spacy/tokens/span.pxd
@@ -2,18 +2,24 @@ cimport numpy as np
from .doc cimport Doc
from ..typedefs cimport attr_t
+from ..structs cimport SpanC
cdef class Span:
cdef readonly Doc doc
- cdef readonly int start
- cdef readonly int end
- cdef readonly int start_char
- cdef readonly int end_char
- cdef readonly attr_t label
- cdef readonly attr_t kb_id
-
+ cdef SpanC c
cdef public _vector
cdef public _vector_norm
+ @staticmethod
+ cdef inline Span cinit(Doc doc, SpanC span):
+ cdef Span self = Span.__new__(
+ Span,
+ doc,
+ start=span.start,
+ end=span.end
+ )
+ self.c = span
+ return self
+
cpdef np.ndarray to_array(self, object features)
diff --git a/spacy/tokens/span.pyx b/spacy/tokens/span.pyx
index 491ba0266..8643816a1 100644
--- a/spacy/tokens/span.pyx
+++ b/spacy/tokens/span.pyx
@@ -97,23 +97,23 @@ cdef class Span:
if not (0 <= start <= end <= len(doc)):
raise IndexError(Errors.E035.format(start=start, end=end, length=len(doc)))
self.doc = doc
- self.start = start
- self.start_char = self.doc[start].idx if start < self.doc.length else 0
- self.end = end
- if end >= 1:
- self.end_char = self.doc[end - 1].idx + len(self.doc[end - 1])
- else:
- self.end_char = 0
if isinstance(label, str):
label = doc.vocab.strings.add(label)
if isinstance(kb_id, str):
kb_id = doc.vocab.strings.add(kb_id)
if label not in doc.vocab.strings:
raise ValueError(Errors.E084.format(label=label))
- self.label = label
+
+ self.c = SpanC(
+ label=label,
+ kb_id=kb_id,
+ start=start,
+ end=end,
+ start_char=doc[start].idx if start < doc.length else 0,
+ end_char=doc[end - 1].idx + len(doc[end - 1]) if end >= 1 else 0,
+ )
self._vector = vector
self._vector_norm = vector_norm
- self.kb_id = kb_id
def __richcmp__(self, Span other, int op):
if other is None:
@@ -123,25 +123,39 @@ cdef class Span:
return True
# <
if op == 0:
- return self.start_char < other.start_char
+ return self.c.start_char < other.c.start_char
# <=
elif op == 1:
- return self.start_char <= other.start_char
+ return self.c.start_char <= other.c.start_char
# ==
elif op == 2:
- return (self.doc, self.start_char, self.end_char, self.label, self.kb_id) == (other.doc, other.start_char, other.end_char, other.label, other.kb_id)
+ # Do the cheap comparisons first
+ return (
+ (self.c.start_char == other.c.start_char) and \
+ (self.c.end_char == other.c.end_char) and \
+ (self.c.label == other.c.label) and \
+ (self.c.kb_id == other.c.kb_id) and \
+ (self.doc == other.doc)
+ )
# !=
elif op == 3:
- return (self.doc, self.start_char, self.end_char, self.label, self.kb_id) != (other.doc, other.start_char, other.end_char, other.label, other.kb_id)
+ # Do the cheap comparisons first
+ return not (
+ (self.c.start_char == other.c.start_char) and \
+ (self.c.end_char == other.c.end_char) and \
+ (self.c.label == other.c.label) and \
+ (self.c.kb_id == other.c.kb_id) and \
+ (self.doc == other.doc)
+ )
# >
elif op == 4:
- return self.start_char > other.start_char
+ return self.c.start_char > other.c.start_char
# >=
elif op == 5:
- return self.start_char >= other.start_char
+ return self.c.start_char >= other.c.start_char
def __hash__(self):
- return hash((self.doc, self.start_char, self.end_char, self.label, self.kb_id))
+ return hash((self.doc, self.c.start_char, self.c.end_char, self.c.label, self.c.kb_id))
def __len__(self):
"""Get the number of tokens in the span.
@@ -150,9 +164,9 @@ cdef class Span:
DOCS: https://nightly.spacy.io/api/span#len
"""
- if self.end < self.start:
+ if self.c.end < self.c.start:
return 0
- return self.end - self.start
+ return self.c.end - self.c.start
def __repr__(self):
return self.text
@@ -171,10 +185,10 @@ cdef class Span:
return Span(self.doc, start + self.start, end + self.start)
else:
if i < 0:
- token_i = self.end + i
+ token_i = self.c.end + i
else:
- token_i = self.start + i
- if self.start <= token_i < self.end:
+ token_i = self.c.start + i
+ if self.c.start <= token_i < self.c.end:
return self.doc[token_i]
else:
raise IndexError(Errors.E1002)
@@ -186,7 +200,7 @@ cdef class Span:
DOCS: https://nightly.spacy.io/api/span#iter
"""
- for i in range(self.start, self.end):
+ for i in range(self.c.start, self.c.end):
yield self.doc[i]
def __reduce__(self):
@@ -196,7 +210,7 @@ cdef class Span:
def _(self):
"""Custom extension attributes registered via `set_extension`."""
return Underscore(Underscore.span_extensions, self,
- start=self.start_char, end=self.end_char)
+ start=self.c.start_char, end=self.c.end_char)
def as_doc(self, *, bint copy_user_data=False):
"""Create a `Doc` object with a copy of the `Span`'s data.
@@ -242,7 +256,7 @@ cdef class Span:
for i in range(length):
# if the HEAD refers to a token outside this span, find a more appropriate ancestor
token = self[i]
- ancestor_i = token.head.i - self.start # span offset
+ ancestor_i = token.head.i - self.c.start # span offset
if ancestor_i not in range(length):
if DEP in attrs:
array[i, attrs.index(DEP)] = dep
@@ -250,7 +264,7 @@ cdef class Span:
# try finding an ancestor within this span
ancestors = token.ancestors
for ancestor in ancestors:
- ancestor_i = ancestor.i - self.start
+ ancestor_i = ancestor.i - self.c.start
if ancestor_i in range(length):
array[i, head_col] = ancestor_i - i
@@ -279,7 +293,7 @@ cdef class Span:
DOCS: https://nightly.spacy.io/api/span#get_lca_matrix
"""
- return numpy.asarray(_get_lca_matrix(self.doc, self.start, self.end))
+ return numpy.asarray(_get_lca_matrix(self.doc, self.c.start, self.c.end))
def similarity(self, other):
"""Make a semantic similarity estimate. The default estimate is cosine
@@ -373,10 +387,14 @@ cdef class Span:
DOCS: https://nightly.spacy.io/api/span#ents
"""
+ cdef Span ent
ents = []
for ent in self.doc.ents:
- if ent.start >= self.start and ent.end <= self.end:
- ents.append(ent)
+ if ent.c.start >= self.c.start:
+ if ent.c.end <= self.c.end:
+ ents.append(ent)
+ else:
+ break
return ents
@property
@@ -513,7 +531,7 @@ cdef class Span:
# with head==0, i.e. a sentence root. If so, we can return it. The
# longer the span, the more likely it contains a sentence root, and
# in this case we return in linear time.
- for i in range(self.start, self.end):
+ for i in range(self.c.start, self.c.end):
if self.doc.c[i].head == 0:
return self.doc[i]
# If we don't have a sentence root, we do something that's not so
@@ -524,15 +542,15 @@ cdef class Span:
# think this should be okay.
cdef int current_best = self.doc.length
cdef int root = -1
- for i in range(self.start, self.end):
- if self.start <= (i+self.doc.c[i].head) < self.end:
+ for i in range(self.c.start, self.c.end):
+ if self.c.start <= (i+self.doc.c[i].head) < self.c.end:
continue
words_to_root = _count_words_to_root(&self.doc.c[i], self.doc.length)
if words_to_root < current_best:
current_best = words_to_root
root = i
if root == -1:
- return self.doc[self.start]
+ return self.doc[self.c.start]
else:
return self.doc[root]
@@ -548,8 +566,8 @@ cdef class Span:
the span.
RETURNS (Span): The newly constructed object.
"""
- start_idx += self.start_char
- end_idx += self.start_char
+ start_idx += self.c.start_char
+ end_idx += self.c.start_char
return self.doc.char_span(start_idx, end_idx)
@property
@@ -628,6 +646,56 @@ cdef class Span:
for word in self.rights:
yield from word.subtree
+ property start:
+ def __get__(self):
+ return self.c.start
+
+ def __set__(self, int start):
+ if start < 0:
+ raise IndexError("TODO")
+ self.c.start = start
+
+ property end:
+ def __get__(self):
+ return self.c.end
+
+ def __set__(self, int end):
+ if end < 0:
+ raise IndexError("TODO")
+ self.c.end = end
+
+ property start_char:
+ def __get__(self):
+ return self.c.start_char
+
+ def __set__(self, int start_char):
+ if start_char < 0:
+ raise IndexError("TODO")
+ self.c.start_char = start_char
+
+ property end_char:
+ def __get__(self):
+ return self.c.end_char
+
+ def __set__(self, int end_char):
+ if end_char < 0:
+ raise IndexError("TODO")
+ self.c.end_char = end_char
+
+ property label:
+ def __get__(self):
+ return self.c.label
+
+ def __set__(self, attr_t label):
+ self.c.label = label
+
+ property kb_id:
+ def __get__(self):
+ return self.c.kb_id
+
+ def __set__(self, attr_t kb_id):
+ self.c.kb_id = kb_id
+
property ent_id:
"""RETURNS (uint64): The entity ID."""
def __get__(self):
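With the data moved into `SpanC`, the former readonly attributes become settable properties with basic bounds checks. A short sketch, assuming a blank pipeline; note that the setters do not recompute the character offsets for you:

```python
import spacy

nlp = spacy.blank("en")
doc = nlp("Some text about Colombia")
span = doc[1:3]

span.end = 4  # widen the span in place; backed by the SpanC struct
assert len(span) == 3  # __len__ reads self.c.end - self.c.start

try:
    span.start = -1  # negative offsets are rejected
except IndexError:
    pass
```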
diff --git a/spacy/tokens/span_group.pxd b/spacy/tokens/span_group.pxd
new file mode 100644
index 000000000..5074aa275
--- /dev/null
+++ b/spacy/tokens/span_group.pxd
@@ -0,0 +1,10 @@
+from libcpp.vector cimport vector
+from ..structs cimport SpanC
+
+cdef class SpanGroup:
+ cdef public object _doc_ref
+ cdef public str name
+ cdef public dict attrs
+ cdef vector[SpanC] c
+
+ cdef void push_back(self, SpanC span) nogil
diff --git a/spacy/tokens/span_group.pyx b/spacy/tokens/span_group.pyx
new file mode 100644
index 000000000..5b768994e
--- /dev/null
+++ b/spacy/tokens/span_group.pyx
@@ -0,0 +1,183 @@
+import weakref
+import struct
+import srsly
+from .span cimport Span
+from libc.stdint cimport uint64_t, uint32_t, int32_t
+
+
+cdef class SpanGroup:
+ """A group of spans that all belong to the same Doc object. The group
+ can be named, and you can attach additional attributes to it. Span groups
+ are generally accessed via the `doc.spans` attribute. The `doc.spans`
+ attribute will convert lists of spans into a `SpanGroup` object for you
+ automatically on assignment.
+
+ Example:
+ Construction 1
+ >>> doc = nlp("Their goi ng home")
+ >>> doc.spans["errors"] = SpanGroup(
+ doc,
+ name="errors",
+ spans=[doc[0:1], doc[2:4]],
+ attrs={"annotator": "matt"}
+ )
+
+ Construction 2
+ >>> doc = nlp("Their goi ng home")
+ >>> doc.spans["errors"] = [doc[0:1], doc[2:4]]
+ >>> assert isinstance(doc.spans["errors"], SpanGroup)
+
+ DOCS: https://nightly.spacy.io/api/spangroup
+ """
+ def __init__(self, doc, *, name="", attrs={}, spans=[]):
+ """Create a SpanGroup.
+
+ doc (Doc): The reference Doc object.
+ name (str): The group name.
+ attrs (Dict[str, Any]): Optional JSON-serializable attributes to attach.
+ spans (Iterable[Span]): The spans to add to the group.
+
+ DOCS: https://nightly.spacy.io/api/spangroup#init
+ """
+ # We need to make this a weak reference, so that the Doc object can
+ # own the SpanGroup without circular references. We do want to get
+ # the Doc though, because otherwise the API gets annoying.
+ self._doc_ref = weakref.ref(doc)
+ self.name = name
+ self.attrs = dict(attrs) if attrs is not None else {}
+ cdef Span span
+ for span in spans:
+ self.push_back(span.c)
+
+ def __repr__(self):
+ return str(list(self))
+
+ @property
+ def doc(self):
+ """RETURNS (Doc): The reference document.
+
+ DOCS: https://nightly.spacy.io/api/spangroup#doc
+ """
+ return self._doc_ref()
+
+ @property
+ def has_overlap(self):
+ """RETURNS (bool): Whether the group contains overlapping spans.
+
+ DOCS: https://nightly.spacy.io/api/spangroup#has_overlap
+ """
+ if not len(self):
+ return False
+ sorted_spans = list(sorted(self))
+ last_end = sorted_spans[0].end
+ for span in sorted_spans[1:]:
+ if span.start < last_end:
+ return True
+ last_end = span.end
+ return False
+
+ def __len__(self):
+ """RETURNS (int): The number of spans in the group.
+
+ DOCS: https://nightly.spacy.io/api/spangroup#len
+ """
+ return self.c.size()
+
+ def append(self, Span span):
+ """Add a span to the group. The span must refer to the same Doc
+ object as the span group.
+
+ span (Span): The span to append.
+
+ DOCS: https://nightly.spacy.io/api/spangroup#append
+ """
+ if span.doc is not self.doc:
+ raise ValueError("Cannot add span to group: refers to different Doc.")
+ self.push_back(span.c)
+
+ def extend(self, spans):
+ """Add multiple spans to the group. All spans must refer to the same
+ Doc object as the span group.
+
+ spans (Iterable[Span]): The spans to add.
+
+ DOCS: https://nightly.spacy.io/api/spangroup#extend
+ """
+ cdef Span span
+ for span in spans:
+ self.append(span)
+
+ def __getitem__(self, int i):
+ """Get a span from the group.
+
+ i (int): The item index.
+ RETURNS (Span): The span at the given index.
+
+ DOCS: https://nightly.spacy.io/api/spangroup#getitem
+ """
+ cdef int size = self.c.size()
+ if i < -size or i >= size:
+ raise IndexError(f"list index {i} out of range")
+ if i < 0:
+ i += size
+ return Span.cinit(self.doc, self.c[i])
+
+ def to_bytes(self):
+ """Serialize the SpanGroup's contents to a byte string.
+
+ RETURNS (bytes): The serialized span group.
+
+ DOCS: https://nightly.spacy.io/api/spangroup#to_bytes
+ """
+ output = {"name": self.name, "attrs": self.attrs, "spans": []}
+ for i in range(self.c.size()):
+ span = self.c[i]
+ # The struct.pack here is probably overkill, but it might help if
+ # you're saving tonnes of spans, and it doesn't really add any
+        # complexity. We do take care to specify a fixed (big-endian) byte
+        # order though, to ensure the message can be loaded back on a
+        # different arch.
+ # Q: uint64_t
+ # q: int64_t
+ # L: uint32_t
+ # l: int32_t
+ output["spans"].append(struct.pack(
+ ">QQQllll",
+ span.id,
+ span.kb_id,
+ span.label,
+ span.start,
+ span.end,
+ span.start_char,
+ span.end_char
+ ))
+ return srsly.msgpack_dumps(output)
+
+ def from_bytes(self, bytes_data):
+ """Deserialize the SpanGroup's contents from a byte string.
+
+ bytes_data (bytes): The span group to load.
+ RETURNS (SpanGroup): The deserialized span group.
+
+ DOCS: https://nightly.spacy.io/api/spangroup#from_bytes
+ """
+ msg = srsly.msgpack_loads(bytes_data)
+ self.name = msg["name"]
+ self.attrs = dict(msg["attrs"])
+ self.c.clear()
+ self.c.reserve(len(msg["spans"]))
+ cdef SpanC span
+ for span_data in msg["spans"]:
+ items = struct.unpack(">QQQllll", span_data)
+ span.id = items[0]
+ span.kb_id = items[1]
+ span.label = items[2]
+ span.start = items[3]
+ span.end = items[4]
+ span.start_char = items[5]
+ span.end_char = items[6]
+ self.c.push_back(span)
+ return self
+
+ cdef void push_back(self, SpanC span) nogil:
+ self.c.push_back(span)
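For reference, each span serializes to a fixed 40-byte record (three `uint64` plus four `int32`, big-endian) inside the msgpack payload. A standalone illustration of the packing used above:

```python
import struct

# id, kb_id, label (uint64) then start, end, start_char, end_char (int32).
record = struct.pack(">QQQllll", 0, 0, 7, 1, 3, 5, 15)
assert struct.calcsize(">QQQllll") == 40

span_id, kb_id, label, start, end, start_char, end_char = struct.unpack(">QQQllll", record)
assert (start, end, start_char, end_char) == (1, 3, 5, 15)
```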
diff --git a/website/docs/api/doc.md b/website/docs/api/doc.md
index 16bbc2700..f3521dae3 100644
--- a/website/docs/api/doc.md
+++ b/website/docs/api/doc.md
@@ -575,6 +575,39 @@ objects, if the entity recognizer has been applied.
| ----------- | --------------------------------------------------------------------- |
| **RETURNS** | Entities in the document, one `Span` per entity. ~~Tuple[Span, ...]~~ |
+## Doc.spans {#spans tag="property"}
+
+A dictionary of named span groups, to store and access additional span
+annotations. You can write to it by assigning a list of [`Span`](/api/span)
+objects or a [`SpanGroup`](/api/spangroup) to a given key.
+
+> #### Example
+>
+> ```python
+> doc = nlp("Their goi ng home")
+> doc.spans["errors"] = [doc[0:1], doc[2:4]]
+> ```
+
+| Name | Description |
+| ----------- | ------------------------------------------------------------------ |
+| **RETURNS** | The span groups assigned to the document. ~~Dict[str, SpanGroup]~~ |
+
+## Doc.cats {#cats tag="property" model="text classifier"}
+
+Maps a label to a score for categories applied to the document. Typically set by
+the [`TextCategorizer`](/api/textcategorizer).
+
+> #### Example
+>
+> ```python
+> doc = nlp("This is a text about football.")
+> print(doc.cats)
+> ```
+
+| Name | Description |
+| ----------- | ---------------------------------------------------------- |
+| **RETURNS** | The text categories mapped to scores. ~~Dict[str, float]~~ |
+
## Doc.noun_chunks {#noun_chunks tag="property" model="parser"}
Iterate over the base noun phrases in the document. Yields base noun-phrase
@@ -668,23 +701,22 @@ The L2 norm of the document's vector representation.
## Attributes {#attributes}
-| Name | Description |
-| ------------------------------------ | ------------------------------------------------------------------------------------------------------------------------------------------- |
-| `text` | A string representation of the document text. ~~str~~ |
-| `text_with_ws` | An alias of `Doc.text`, provided for duck-type compatibility with `Span` and `Token`. ~~str~~ |
-| `mem` | The document's local memory heap, for all C data it owns. ~~cymem.Pool~~ |
-| `vocab` | The store of lexical types. ~~Vocab~~ |
-| `tensor` 2 | Container for dense vector representations. ~~numpy.ndarray~~ |
-| `cats` 2 | Maps a label to a score for categories applied to the document. The label is a string and the score should be a float. ~~Dict[str, float]~~ |
-| `user_data` | A generic storage area, for user custom data. ~~Dict[str, Any]~~ |
-| `lang` 2.1 | Language of the document's vocabulary. ~~int~~ |
-| `lang_` 2.1 | Language of the document's vocabulary. ~~str~~ |
-| `sentiment` | The document's positivity/negativity score, if available. ~~float~~ |
-| `user_hooks` | A dictionary that allows customization of the `Doc`'s properties. ~~Dict[str, Callable]~~ |
-| `user_token_hooks` | A dictionary that allows customization of properties of `Token` children. ~~Dict[str, Callable]~~ |
-| `user_span_hooks` | A dictionary that allows customization of properties of `Span` children. ~~Dict[str, Callable]~~ |
-| `has_unknown_spaces` | Whether the document was constructed without known spacing between tokens (typically when created from gold tokenization). ~~bool~~ |
-| `_` | User space for adding custom [attribute extensions](/usage/processing-pipelines#custom-components-attributes). ~~Underscore~~ |
+| Name | Description |
+| ------------------------------------ | ----------------------------------------------------------------------------------------------------------------------------------- |
+| `text` | A string representation of the document text. ~~str~~ |
+| `text_with_ws` | An alias of `Doc.text`, provided for duck-type compatibility with `Span` and `Token`. ~~str~~ |
+| `mem` | The document's local memory heap, for all C data it owns. ~~cymem.Pool~~ |
+| `vocab` | The store of lexical types. ~~Vocab~~ |
+| `tensor` 2 | Container for dense vector representations. ~~numpy.ndarray~~ |
+| `user_data` | A generic storage area, for user custom data. ~~Dict[str, Any]~~ |
+| `lang` 2.1 | Language of the document's vocabulary. ~~int~~ |
+| `lang_` 2.1 | Language of the document's vocabulary. ~~str~~ |
+| `sentiment` | The document's positivity/negativity score, if available. ~~float~~ |
+| `user_hooks` | A dictionary that allows customization of the `Doc`'s properties. ~~Dict[str, Callable]~~ |
+| `user_token_hooks` | A dictionary that allows customization of properties of `Token` children. ~~Dict[str, Callable]~~ |
+| `user_span_hooks` | A dictionary that allows customization of properties of `Span` children. ~~Dict[str, Callable]~~ |
+| `has_unknown_spaces` | Whether the document was constructed without known spacing between tokens (typically when created from gold tokenization). ~~bool~~ |
+| `_` | User space for adding custom [attribute extensions](/usage/processing-pipelines#custom-components-attributes). ~~Underscore~~ |
## Serialization fields {#serialization-fields}
diff --git a/website/docs/api/spangroup.md b/website/docs/api/spangroup.md
new file mode 100644
index 000000000..ba248f376
--- /dev/null
+++ b/website/docs/api/spangroup.md
@@ -0,0 +1,185 @@
+---
+title: SpanGroup
+tag: class
+source: spacy/tokens/span_group.pyx
+new: 3
+---
+
+A group of arbitrary, potentially overlapping [`Span`](/api/span) objects that
+all belong to the same [`Doc`](/api/doc) object. The group can be named, and you
+can attach additional attributes to it. Span groups are generally accessed via
+the [`Doc.spans`](/api/doc#spans) attribute, which will convert lists of spans
+into a `SpanGroup` object for you automatically on assignment. `SpanGroup`
+objects behave similarly to `list`s, so you can append `Span` objects to them
+and access a member at a given index.
+
+## SpanGroup.\_\_init\_\_ {#init tag="method"}
+
+Create a `SpanGroup`.
+
+> #### Example
+>
+> ```python
+> doc = nlp("Their goi ng home")
+> spans = [doc[0:1], doc[2:4]]
+>
+> # Construction 1
+> from spacy.tokens import SpanGroup
+>
+> group = SpanGroup(doc, name="errors", spans=spans, attrs={"annotator": "matt"})
+> doc.spans["errors"] = group
+>
+> # Construction 2
+> doc.spans["errors"] = spans
+> assert isinstance(doc.spans["errors"], SpanGroup)
+> ```
+
+| Name | Description |
+| -------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------- |
+| `doc` | The document the span group belongs to. ~~Doc~~ |
+| _keyword-only_ | |
+| `name` | The name of the span group. If the span group is created automatically on assignment to `doc.spans`, the key name is used. Defaults to `""`. ~~str~~ |
+| `attrs` | Optional JSON-serializable attributes to attach to the span group. ~~Dict[str, Any]~~ |
+| `spans` | The spans to add to the span group. ~~Iterable[Span]~~ |
+
+## SpanGroup.doc {#doc tag="property"}
+
+The [`Doc`](/api/doc) object the span group is referring to.
+
+> #### Example
+>
+> ```python
+> doc = nlp("Their goi ng home")
+> doc.spans["errors"] = [doc[0:1], doc[2:4]]
+> assert doc.spans["errors"].doc == doc
+> ```
+
+| Name | Description |
+| ----------- | ------------------------------- |
+| **RETURNS** | The reference document. ~~Doc~~ |
+
+## SpanGroup.has_overlap {#has_overlap tag="property"}
+
+Check whether the span group contains overlapping spans.
+
+> #### Example
+>
+> ```python
+> doc = nlp("Their goi ng home")
+> doc.spans["errors"] = [doc[0:1], doc[2:4]]
+> assert not doc.spans["errors"].has_overlap
+> doc.spans["errors"].append(doc[1:2])
+> assert doc.spans["errors"].has_overlap
+> ```
+
+| Name | Description |
+| ----------- | -------------------------------------------------- |
+| **RETURNS** | Whether the span group contains overlaps. ~~bool~~ |
+
+## SpanGroup.\_\_len\_\_ {#len tag="method"}
+
+Get the number of spans in the group.
+
+> #### Example
+>
+> ```python
+> doc = nlp("Their goi ng home")
+> doc.spans["errors"] = [doc[0:1], doc[2:4]]
+> assert len(doc.spans["errors"]) == 2
+> ```
+
+| Name | Description |
+| ----------- | ----------------------------------------- |
+| **RETURNS** | The number of spans in the group. ~~int~~ |
+
+## SpanGroup.\_\_getitem\_\_ {#getitem tag="method"}
+
+Get a span from the group.
+
+> #### Example
+>
+> ```python
+> doc = nlp("Their goi ng home")
+> doc.spans["errors"] = [doc[0:1], doc[2:4]]
+> span = doc.spans["errors"][1]
+> assert span.text == "goi ng"
+> ```
+
+| Name | Description |
+| ----------- | ------------------------------------- |
+| `i` | The item index. ~~int~~ |
+| **RETURNS** | The span at the given index. ~~Span~~ |
+
+## SpanGroup.append {#append tag="method"}
+
+Add a [`Span`](/api/span) object to the group. The span must refer to the same
+[`Doc`](/api/doc) object as the span group.
+
+> #### Example
+>
+> ```python
+> doc = nlp("Their goi ng home")
+> doc.spans["errors"] = [doc[0:1]]
+> doc.spans["errors"].append(doc[2:4])
+> assert len(doc.spans["errors"]) == 2
+> ```
+
+| Name | Description |
+| ------ | ---------------------------- |
+| `span` | The span to append. ~~Span~~ |
+
+## SpanGroup.extend {#extend tag="method"}
+
+Add multiple [`Span`](/api/span) objects to the group. All spans must refer to
+the same [`Doc`](/api/doc) object as the span group.
+
+> #### Example
+>
+> ```python
+> doc = nlp("Their goi ng home")
+> doc.spans["errors"] = []
+> doc.spans["errors"].extend([doc[2:4], doc[0:1]])
+> assert len(doc.spans["errors"]) == 2
+> ```
+
+| Name | Description |
+| ------- | ------------------------------------ |
+| `spans` | The spans to add. ~~Iterable[Span]~~ |
+
+## SpanGroup.to_bytes {#to_bytes tag="method"}
+
+Serialize the span group to a bytestring.
+
+> #### Example
+>
+> ```python
+> doc = nlp("Their goi ng home")
+> doc.spans["errors"] = [doc[0:1], doc[2:4]]
+> group_bytes = doc.spans["errors"].to_bytes()
+> ```
+
+| Name | Description |
+| ----------- | ------------------------------------- |
+| **RETURNS** | The serialized `SpanGroup`. ~~bytes~~ |
+
+## SpanGroup.from_bytes {#from_bytes tag="method"}
+
+Load the span group from a bytestring. Modifies the object in place and returns
+it.
+
+> #### Example
+>
+> ```python
+> from spacy.tokens import SpanGroup
+>
+> doc = nlp("Their goi ng home")
+> doc.spans["errors"] = [doc[0:1], doc[2:4]]
+> group_bytes = doc.spans["errors"].to_bytes()
+> new_group = SpanGroup(doc)
+> new_group.from_bytes(group_bytes)
+> ```
+
+| Name | Description |
+| ------------ | ------------------------------------- |
+| `bytes_data` | The data to load from. ~~bytes~~ |
+| **RETURNS** | The `SpanGroup` object. ~~SpanGroup~~ |
diff --git a/website/docs/usage/101/_architecture.md b/website/docs/usage/101/_architecture.md
index b012c4ec0..8fb452895 100644
--- a/website/docs/usage/101/_architecture.md
+++ b/website/docs/usage/101/_architecture.md
@@ -18,15 +18,16 @@ It also orchestrates training and serialization.
### Container objects {#architecture-containers}
-| Name | Description |
-| --------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------- |
-| [`Doc`](/api/doc) | A container for accessing linguistic annotations. |
-| [`DocBin`](/api/docbin) | A collection of `Doc` objects for efficient binary serialization. Also used for [training data](/api/data-formats#binary-training). |
-| [`Example`](/api/example) | A collection of training annotations, containing two `Doc` objects: the reference data and the predictions. |
-| [`Language`](/api/language) | Processing class that turns text into `Doc` objects. Different languages implement their own subclasses of it. The variable is typically called `nlp`. |
-| [`Lexeme`](/api/lexeme) | An entry in the vocabulary. It's a word type with no context, as opposed to a word token. It therefore has no part-of-speech tag, dependency parse etc. |
-| [`Span`](/api/span) | A slice from a `Doc` object. |
-| [`Token`](/api/token) | An individual token — i.e. a word, punctuation symbol, whitespace, etc. |
+| Name | Description |
+| ----------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| [`Doc`](/api/doc) | A container for accessing linguistic annotations. |
+| [`DocBin`](/api/docbin) | A collection of `Doc` objects for efficient binary serialization. Also used for [training data](/api/data-formats#binary-training). |
+| [`Example`](/api/example) | A collection of training annotations, containing two `Doc` objects: the reference data and the predictions. |
+| [`Language`](/api/language) | Processing class that turns text into `Doc` objects. Different languages implement their own subclasses of it. The variable is typically called `nlp`. |
+| [`Lexeme`](/api/lexeme) | An entry in the vocabulary. It's a word type with no context, as opposed to a word token. It therefore has no part-of-speech tag, dependency parse etc. |
+| [`Span`](/api/span) | A slice from a `Doc` object. |
+| [`SpanGroup`](/api/spangroup) | A named collection of spans belonging to a `Doc`. |
+| [`Token`](/api/token) | An individual token — i.e. a word, punctuation symbol, whitespace, etc. |
### Processing pipeline {#architecture-pipeline}
diff --git a/website/docs/usage/v3.md b/website/docs/usage/v3.md
index 9b911b960..6b21ec383 100644
--- a/website/docs/usage/v3.md
+++ b/website/docs/usage/v3.md
@@ -501,7 +501,7 @@ format for documenting argument and return types.
[`AttributeRuler`](/api/attributeruler),
[`SentenceRecognizer`](/api/sentencerecognizer),
[`DependencyMatcher`](/api/dependencymatcher), [`TrainablePipe`](/api/pipe),
- [`Corpus`](/api/corpus)
+ [`Corpus`](/api/corpus), [`SpanGroup`](/api/spangroup)
diff --git a/website/meta/sidebars.json b/website/meta/sidebars.json
index 3799f399b..d3a0726e6 100644
--- a/website/meta/sidebars.json
+++ b/website/meta/sidebars.json
@@ -77,6 +77,7 @@
{ "text": "Language", "url": "/api/language" },
{ "text": "Lexeme", "url": "/api/lexeme" },
{ "text": "Span", "url": "/api/span" },
+ { "text": "SpanGroup", "url": "/api/spangroup" },
{ "text": "Token", "url": "/api/token" }
]
},
diff --git a/website/meta/type-annotations.json b/website/meta/type-annotations.json
index acbc88ae2..8136b3e96 100644
--- a/website/meta/type-annotations.json
+++ b/website/meta/type-annotations.json
@@ -2,6 +2,7 @@
"Doc": "/api/doc",
"Token": "/api/token",
"Span": "/api/span",
+ "SpanGroup": "/api/spangroup",
"Lexeme": "/api/lexeme",
"Example": "/api/example",
"Alignment": "/api/example#alignment-object",