Use marisa_trie for alias2qids and qid2typeqid to save memory
This commit is contained in:
parent
b64b952948
commit
9dc6295d45
|
@ -32,7 +32,6 @@ import os
|
|||
import re
|
||||
|
||||
import marisa_trie
|
||||
import ujson
|
||||
|
||||
from ..data_utils.almond_utils import quoted_pattern_with_space
|
||||
from ..ned.ned_utils import has_overlap, is_banned, normalize_text
|
||||
|
@ -44,6 +43,7 @@ logger = logging.getLogger(__name__)
|
|||
class BaseEntityDisambiguator(AbstractEntityDisambiguator):
|
||||
def __init__(self, args):
|
||||
super().__init__(args)
|
||||
self.alias2type = marisa_trie.RecordTrie("<p")
|
||||
|
||||
def process_examples(self, examples, split_path, utterance_field):
|
||||
all_token_type_ids, all_token_type_probs, all_token_qids = [], [], []
|
||||
|
@ -71,7 +71,7 @@ class BaseEntityDisambiguator(AbstractEntityDisambiguator):
|
|||
|
||||
self.replace_features_inplace(examples, all_token_type_ids, all_token_type_probs, all_token_qids, utterance_field)
|
||||
|
||||
def find_type_ids(self, tokens, answer=None):
|
||||
def find_type_ids(self, tokens, answer):
|
||||
# each subclass should implement their own find_type_ids method
|
||||
raise NotImplementedError()
|
||||
|
||||
|
@ -103,11 +103,11 @@ class BaseEntityDisambiguator(AbstractEntityDisambiguator):
|
|||
end += 1
|
||||
gram_text = normalize_text(" ".join(gram))
|
||||
|
||||
if not is_banned(gram_text) and gram_text not in verbs and gram_text in self.all_aliases:
|
||||
if not is_banned(gram_text) and gram_text not in verbs and gram_text in self.alias2type:
|
||||
if has_overlap(start, end, used_aliases):
|
||||
continue
|
||||
|
||||
used_aliases.append([self.typeqid2id.get(self.alias2type[gram_text], self.unk_id), start, end])
|
||||
used_aliases.append([self.typeqid2id.get(self.alias2type[gram_text][0], self.unk_id), start, end])
|
||||
|
||||
for type_id, beg, end in used_aliases:
|
||||
tokens_type_ids[beg:end] = [[type_id] * self.max_features_size] * (end - beg)
|
||||
|
@ -120,7 +120,7 @@ class BaseEntityDisambiguator(AbstractEntityDisambiguator):
|
|||
while i < len(tokens):
|
||||
token = tokens[i]
|
||||
# sort by number of tokens so longer keys get matched first
|
||||
matched_items = sorted(self.all_aliases.keys(token), key=lambda item: len(item), reverse=True)
|
||||
matched_items = sorted(self.alias2type.keys(token), key=lambda item: len(item), reverse=True)
|
||||
found = False
|
||||
for key in matched_items:
|
||||
type = self.alias2type[key]
|
||||
|
@ -161,7 +161,7 @@ class BaseEntityDisambiguator(AbstractEntityDisambiguator):
|
|||
end = length
|
||||
while end > i:
|
||||
tokens_str = ' '.join(tokens[i:end])
|
||||
if tokens_str in self.all_aliases:
|
||||
if tokens_str in self.alias2type:
|
||||
# match found
|
||||
found = True
|
||||
tokens_type_ids.extend(
|
||||
|
@ -184,7 +184,7 @@ class BaseEntityDisambiguator(AbstractEntityDisambiguator):
|
|||
tokens_text = " ".join(tokens)
|
||||
|
||||
for ent in entities:
|
||||
if ent not in self.all_aliases:
|
||||
if ent not in self.alias2type:
|
||||
continue
|
||||
ent_num_tokens = len(ent.split(' '))
|
||||
idx = tokens_text.index(ent)
|
||||
|
@ -208,12 +208,9 @@ class BaseEntityDisambiguator(AbstractEntityDisambiguator):
|
|||
class NaiveEntityDisambiguator(BaseEntityDisambiguator):
|
||||
def __init__(self, args):
|
||||
super().__init__(args)
|
||||
|
||||
with open(os.path.join(self.args.database_dir, 'es_material/alias2type.json'), 'r') as fin:
|
||||
# alias2type.json is a big file (>4G); load it only when necessary
|
||||
self.alias2type = ujson.load(fin)
|
||||
all_aliases = marisa_trie.Trie(self.alias2type.keys())
|
||||
self.all_aliases = all_aliases
|
||||
self.alias2type = marisa_trie.RecordTrie("<p").mmap(
|
||||
os.path.join(self.args.database_dir, 'es_material/alias2typeqid.marisa')
|
||||
)
|
||||
|
||||
def find_type_ids(self, tokens, answer=None):
|
||||
tokens_type_ids = self.lookup(
|
||||
|
@ -225,17 +222,13 @@ class NaiveEntityDisambiguator(BaseEntityDisambiguator):
|
|||
class EntityOracleEntityDisambiguator(BaseEntityDisambiguator):
|
||||
def __init__(self, args):
|
||||
super().__init__(args)
|
||||
|
||||
with open(os.path.join(self.args.database_dir, 'es_material/alias2type.json'), 'r') as fin:
|
||||
# alias2type.json is a big file (>4G); load it only when necessary
|
||||
self.alias2type = ujson.load(fin)
|
||||
all_aliases = marisa_trie.Trie(self.alias2type.keys())
|
||||
self.all_aliases = all_aliases
|
||||
self.alias2type = marisa_trie.RecordTrie("<p").mmap(
|
||||
os.path.join(self.args.database_dir, 'es_material/alias2typeqid.marisa')
|
||||
)
|
||||
|
||||
def find_type_ids(self, tokens, answer):
|
||||
answer_entities = quoted_pattern_with_space.findall(answer)
|
||||
tokens_type_ids = self.lookup_entities(tokens, answer_entities)
|
||||
|
||||
return tokens_type_ids
|
||||
|
||||
|
||||
|
@ -280,11 +273,9 @@ class TypeOracleEntityDisambiguator(BaseEntityDisambiguator):
|
|||
# this is usually caused by paraphrasing where it adds "-" after entity name: "korean-style restaurants"
|
||||
# or add "'" before or after an entity
|
||||
if not re.search(rf'(^|\s){ent}($|\s)', sentence):
|
||||
# print(f'***ent: {ent} {tokens} {answer}')
|
||||
continue
|
||||
|
||||
# ** this should change if thingtalk syntax changes **
|
||||
|
||||
# ( ... [Book|Music|...] ( ) filter id =~ " position and affirm " ) ...'
|
||||
# ... ^^org.schema.Book:Person ( " james clavell " ) ...
|
||||
# ... contains~ ( [award|genre|...] , " booker " ) ...
|
||||
|
|
Binary file not shown.
|
@ -1 +0,0 @@
|
|||
{"dahr el ouahch":"Q8502","sahlet el ouili":"Q39816","abeking":"Q4167410","substrate profiling of tobacco etch virus protease using a novel fluorescence-assisted whole-cell assay":"Q13442814","culture conversion among hiv co-infected multidrug-resistant tuberculosis patients in tugela ferry , south africa":"Q13442814","characterization of burkholderia rhizoxinica and b. endofungorum isolated from clinical specimens":"Q13442814","increasing incidence of geomyces destructans fungus in bats from the czech republic and slovakia":"Q13442814","molecular mapping of movement-associated areas in the avian brain : a motor theory for vocal learning origin":"Q13442814","sar el adra":"Q8502","cell lineage of the ilyanassa embryo : evolutionary acceleration of regional differentiation during early development":"Q13442814","a survey of honey bee colony losses in the u.s. , fall 2007 to spring 2008":"Q13442814","mitochondrial dna haplogroup d4a is a marker for extreme longevity in japan":"Q13442814","cost-effectiveness of newborn circumcision in reducing lifetime hiv risk among u.s. males":"Q13442814","charabech":"Q8502","chiaab rizq":"Q8502"}
|
Binary file not shown.
Loading…
Reference in New Issue