Use marisa_trie for alias2qids and qid2typeqid to save memory

This commit is contained in:
mehrad 2021-08-08 15:50:20 -07:00
parent b64b952948
commit 9dc6295d45
No known key found for this signature in database
GPG Key ID: AAF81F778210AE42
4 changed files with 13 additions and 23 deletions

View File

@ -32,7 +32,6 @@ import os
import re
import marisa_trie
import ujson
from ..data_utils.almond_utils import quoted_pattern_with_space
from ..ned.ned_utils import has_overlap, is_banned, normalize_text
@ -44,6 +43,7 @@ logger = logging.getLogger(__name__)
class BaseEntityDisambiguator(AbstractEntityDisambiguator):
def __init__(self, args):
super().__init__(args)
self.alias2type = marisa_trie.RecordTrie("<p")
def process_examples(self, examples, split_path, utterance_field):
all_token_type_ids, all_token_type_probs, all_token_qids = [], [], []
@ -71,7 +71,7 @@ class BaseEntityDisambiguator(AbstractEntityDisambiguator):
self.replace_features_inplace(examples, all_token_type_ids, all_token_type_probs, all_token_qids, utterance_field)
def find_type_ids(self, tokens, answer=None):
def find_type_ids(self, tokens, answer):
# each subclass should implement their own find_type_ids method
raise NotImplementedError()
@ -103,11 +103,11 @@ class BaseEntityDisambiguator(AbstractEntityDisambiguator):
end += 1
gram_text = normalize_text(" ".join(gram))
if not is_banned(gram_text) and gram_text not in verbs and gram_text in self.all_aliases:
if not is_banned(gram_text) and gram_text not in verbs and gram_text in self.alias2type:
if has_overlap(start, end, used_aliases):
continue
used_aliases.append([self.typeqid2id.get(self.alias2type[gram_text], self.unk_id), start, end])
used_aliases.append([self.typeqid2id.get(self.alias2type[gram_text][0], self.unk_id), start, end])
for type_id, beg, end in used_aliases:
tokens_type_ids[beg:end] = [[type_id] * self.max_features_size] * (end - beg)
@ -120,7 +120,7 @@ class BaseEntityDisambiguator(AbstractEntityDisambiguator):
while i < len(tokens):
token = tokens[i]
# sort by number of tokens so longer keys get matched first
matched_items = sorted(self.all_aliases.keys(token), key=lambda item: len(item), reverse=True)
matched_items = sorted(self.alias2type.keys(token), key=lambda item: len(item), reverse=True)
found = False
for key in matched_items:
type = self.alias2type[key]
@ -161,7 +161,7 @@ class BaseEntityDisambiguator(AbstractEntityDisambiguator):
end = length
while end > i:
tokens_str = ' '.join(tokens[i:end])
if tokens_str in self.all_aliases:
if tokens_str in self.alias2type:
# match found
found = True
tokens_type_ids.extend(
@ -184,7 +184,7 @@ class BaseEntityDisambiguator(AbstractEntityDisambiguator):
tokens_text = " ".join(tokens)
for ent in entities:
if ent not in self.all_aliases:
if ent not in self.alias2type:
continue
ent_num_tokens = len(ent.split(' '))
idx = tokens_text.index(ent)
@ -208,12 +208,9 @@ class BaseEntityDisambiguator(AbstractEntityDisambiguator):
class NaiveEntityDisambiguator(BaseEntityDisambiguator):
def __init__(self, args):
super().__init__(args)
with open(os.path.join(self.args.database_dir, 'es_material/alias2type.json'), 'r') as fin:
# alias2type.json is a big file (>4G); load it only when necessary
self.alias2type = ujson.load(fin)
all_aliases = marisa_trie.Trie(self.alias2type.keys())
self.all_aliases = all_aliases
self.alias2type = marisa_trie.RecordTrie("<p").mmap(
os.path.join(self.args.database_dir, 'es_material/alias2typeqid.marisa')
)
def find_type_ids(self, tokens, answer=None):
tokens_type_ids = self.lookup(
@ -225,17 +222,13 @@ class NaiveEntityDisambiguator(BaseEntityDisambiguator):
class EntityOracleEntityDisambiguator(BaseEntityDisambiguator):
def __init__(self, args):
super().__init__(args)
with open(os.path.join(self.args.database_dir, 'es_material/alias2type.json'), 'r') as fin:
# alias2type.json is a big file (>4G); load it only when necessary
self.alias2type = ujson.load(fin)
all_aliases = marisa_trie.Trie(self.alias2type.keys())
self.all_aliases = all_aliases
self.alias2type = marisa_trie.RecordTrie("<p").mmap(
os.path.join(self.args.database_dir, 'es_material/alias2typeqid.marisa')
)
def find_type_ids(self, tokens, answer):
answer_entities = quoted_pattern_with_space.findall(answer)
tokens_type_ids = self.lookup_entities(tokens, answer_entities)
return tokens_type_ids
@ -280,11 +273,9 @@ class TypeOracleEntityDisambiguator(BaseEntityDisambiguator):
# this is usually caused by paraphrasing where it adds "-" after entity name: "korean-style restaurants"
# or add "'" before or after an entity
if not re.search(rf'(^|\s){ent}($|\s)', sentence):
# print(f'***ent: {ent} {tokens} {answer}')
continue
# ** this should change if thingtalk syntax changes **
# ( ... [Book|Music|...] ( ) filter id =~ " position and affirm " ) ...'
# ... ^^org.schema.Book:Person ( " james clavell " ) ...
# ... contains~ ( [award|genre|...] , " booker " ) ...

Binary file not shown.

View File

@ -1 +0,0 @@
{"dahr el ouahch":"Q8502","sahlet el ouili":"Q39816","abeking":"Q4167410","substrate profiling of tobacco etch virus protease using a novel fluorescence-assisted whole-cell assay":"Q13442814","culture conversion among hiv co-infected multidrug-resistant tuberculosis patients in tugela ferry , south africa":"Q13442814","characterization of burkholderia rhizoxinica and b. endofungorum isolated from clinical specimens":"Q13442814","increasing incidence of geomyces destructans fungus in bats from the czech republic and slovakia":"Q13442814","molecular mapping of movement-associated areas in the avian brain : a motor theory for vocal learning origin":"Q13442814","sar el adra":"Q8502","cell lineage of the ilyanassa embryo : evolutionary acceleration of regional differentiation during early development":"Q13442814","a survey of honey bee colony losses in the u.s. , fall 2007 to spring 2008":"Q13442814","mitochondrial dna haplogroup d4a is a marker for extreme longevity in japan":"Q13442814","cost-effectiveness of newborn circumcision in reducing lifetime hiv risk among u.s. males":"Q13442814","charabech":"Q8502","chiaab rizq":"Q8502"}

Binary file not shown.