mirror of https://github.com/explosion/spaCy.git
fix for Issue #4000
This commit is contained in:
parent
dae8a21282
commit
9f8c1e71a2
|
@ -1,6 +1,8 @@
|
||||||
# coding: utf-8
|
# coding: utf-8
|
||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
|
import os
|
||||||
|
from os import path
|
||||||
import random
|
import random
|
||||||
import datetime
|
import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
@ -26,7 +28,8 @@ ENTITY_COUNTS = OUTPUT_DIR / "entity_freq.csv"
|
||||||
ENTITY_DEFS = OUTPUT_DIR / "entity_defs.csv"
|
ENTITY_DEFS = OUTPUT_DIR / "entity_defs.csv"
|
||||||
ENTITY_DESCR = OUTPUT_DIR / "entity_descriptions.csv"
|
ENTITY_DESCR = OUTPUT_DIR / "entity_descriptions.csv"
|
||||||
|
|
||||||
KB_FILE = OUTPUT_DIR / "kb_1" / "kb"
|
KB_DIR = OUTPUT_DIR / "kb_1"
|
||||||
|
KB_FILE = "kb"
|
||||||
NLP_1_DIR = OUTPUT_DIR / "nlp_1"
|
NLP_1_DIR = OUTPUT_DIR / "nlp_1"
|
||||||
NLP_2_DIR = OUTPUT_DIR / "nlp_2"
|
NLP_2_DIR = OUTPUT_DIR / "nlp_2"
|
||||||
|
|
||||||
|
@ -118,7 +121,10 @@ def run_pipeline():
|
||||||
print()
|
print()
|
||||||
|
|
||||||
print("STEP 3b: write KB and NLP", now())
|
print("STEP 3b: write KB and NLP", now())
|
||||||
kb_1.dump(KB_FILE)
|
|
||||||
|
if not path.exists(KB_DIR):
|
||||||
|
os.makedirs(KB_DIR)
|
||||||
|
kb_1.dump(KB_DIR / KB_FILE)
|
||||||
nlp_1.to_disk(NLP_1_DIR)
|
nlp_1.to_disk(NLP_1_DIR)
|
||||||
print()
|
print()
|
||||||
|
|
||||||
|
@ -127,7 +133,7 @@ def run_pipeline():
|
||||||
print("STEP 4: to_read_kb", now())
|
print("STEP 4: to_read_kb", now())
|
||||||
nlp_2 = spacy.load(NLP_1_DIR)
|
nlp_2 = spacy.load(NLP_1_DIR)
|
||||||
kb_2 = KnowledgeBase(vocab=nlp_2.vocab, entity_vector_length=DESC_WIDTH)
|
kb_2 = KnowledgeBase(vocab=nlp_2.vocab, entity_vector_length=DESC_WIDTH)
|
||||||
kb_2.load_bulk(KB_FILE)
|
kb_2.load_bulk(KB_DIR / KB_FILE)
|
||||||
print("kb entities:", kb_2.get_size_entities())
|
print("kb entities:", kb_2.get_size_entities())
|
||||||
print("kb aliases:", kb_2.get_size_aliases())
|
print("kb aliases:", kb_2.get_size_aliases())
|
||||||
print()
|
print()
|
||||||
|
@ -327,7 +333,8 @@ def _measure_acc(data, el_pipe=None, error_analysis=False):
|
||||||
# only evaluating on positive examples
|
# only evaluating on positive examples
|
||||||
for gold_kb, value in kb_dict.items():
|
for gold_kb, value in kb_dict.items():
|
||||||
if value:
|
if value:
|
||||||
correct_entries_per_article[str(start) + "-" + str(end)] = gold_kb
|
offset = str(start) + "-" + str(end)
|
||||||
|
correct_entries_per_article[offset] = gold_kb
|
||||||
|
|
||||||
for ent in doc.ents:
|
for ent in doc.ents:
|
||||||
ent_label = ent.label_
|
ent_label = ent.label_
|
||||||
|
@ -385,7 +392,8 @@ def _measure_baselines(data, kb):
|
||||||
for gold_kb, value in kb_dict.items():
|
for gold_kb, value in kb_dict.items():
|
||||||
# only evaluating on positive examples
|
# only evaluating on positive examples
|
||||||
if value:
|
if value:
|
||||||
correct_entries_per_article[str(start) + "-" + str(end)] = gold_kb
|
offset = str(start) + "-" + str(end)
|
||||||
|
correct_entries_per_article[offset] = gold_kb
|
||||||
|
|
||||||
for ent in doc.ents:
|
for ent in doc.ents:
|
||||||
label = ent.label_
|
label = ent.label_
|
||||||
|
|
|
@ -278,7 +278,7 @@ cdef class KnowledgeBase:
|
||||||
cdef hash_t entity_hash
|
cdef hash_t entity_hash
|
||||||
cdef hash_t alias_hash
|
cdef hash_t alias_hash
|
||||||
cdef int64_t entry_index
|
cdef int64_t entry_index
|
||||||
cdef float freq
|
cdef float freq, prob
|
||||||
cdef int32_t vector_index
|
cdef int32_t vector_index
|
||||||
cdef KBEntryC entry
|
cdef KBEntryC entry
|
||||||
cdef AliasC alias
|
cdef AliasC alias
|
||||||
|
@ -373,7 +373,7 @@ cdef class Writer:
|
||||||
loc = bytes(loc)
|
loc = bytes(loc)
|
||||||
cdef bytes bytes_loc = loc.encode('utf8') if type(loc) == unicode else loc
|
cdef bytes bytes_loc = loc.encode('utf8') if type(loc) == unicode else loc
|
||||||
self._fp = fopen(<char*>bytes_loc, 'wb')
|
self._fp = fopen(<char*>bytes_loc, 'wb')
|
||||||
assert self._fp != NULL
|
assert self._fp != NULL, "Could not access %s" % loc
|
||||||
fseek(self._fp, 0, 0)
|
fseek(self._fp, 0, 0)
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
|
|
Loading…
Reference in New Issue