diff --git a/spacy/strings.pyx b/spacy/strings.pyx
index d54dcdf1a..d11936d12 100644
--- a/spacy/strings.pyx
+++ b/spacy/strings.pyx
@@ -118,6 +118,12 @@ cdef class StringStore:
         else:
             raise TypeError(type(string_or_id))
 
+    def __contains__(self, unicode string):
+        """Return True if the hash of `string` is a key in the store's map."""
+        cdef hash_t key = hash_string(string)
+        value = self._map.get(key)
+        return value is not NULL
+
     def __iter__(self):
         cdef int i
         for i in range(self.size):
diff --git a/spacy/tests/vocab/test_vocab.py b/spacy/tests/vocab/test_vocab.py
index 27aea46ff..f62cddf0e 100644
--- a/spacy/tests/vocab/test_vocab.py
+++ b/spacy/tests/vocab/test_vocab.py
@@ -43,6 +43,12 @@ def test_symbols(en_vocab):
     assert en_vocab.strings['LEMMA'] == LEMMA
     assert en_vocab.strings['ORTH'] == ORTH
     assert en_vocab.strings['PROB'] == PROB
+
+
+def test_contains(en_vocab):
+    """`in` on the vocab holds for 'Hello' and fails for a junk string."""
+    assert 'Hello' in en_vocab
+    assert 'LKsdjvlsakdvlaksdvlkasjdvljasdlkfvm' not in en_vocab
 
 
 @pytest.mark.xfail
diff --git a/spacy/vocab.pyx b/spacy/vocab.pyx
index 391adfa28..f876bfefb 100644
--- a/spacy/vocab.pyx
+++ b/spacy/vocab.pyx
@@ -172,6 +172,12 @@ cdef class Vocab:
         self._by_orth.set(lex.orth, lex)
         self.length += 1
 
+    def __contains__(self, unicode string):
+        """Return True if the hash of `string` is a key in `_by_hash`."""
+        key = hash_string(string)
+        lex = self._by_hash.get(key)
+        return lex is not NULL
+
     def __iter__(self):
         cdef attr_t orth
         cdef size_t addr