Refactor Language to use the new Defaults class, and work on revised data loading. We're getting rid of sputnik's weird file-system wrapper and using pathlib instead.

Matthew Honnibal 2016-09-24 14:08:53 +02:00
parent b00f683a0c
commit 9dc8043a7e
1 changed file with 131 additions and 137 deletions

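The diff replaces sputnik's package objects with plain pathlib paths, and moves component construction into an overridable Defaults factory class. A minimal sketch of the loading pattern this enables, assuming a model directory containing the vocab/, deps/ and ner/ subdirectories the new code expects (the module path spacy.language and the model directory here are assumptions for illustration):

    import pathlib

    from spacy.language import Language

    # Every component argument defaults to True, meaning "build the
    # default from `path`"; pass an object to inject your own, or None
    # to skip loading that component.
    nlp = Language(path=pathlib.Path('/models/en_default'))

    # A plain string also works; __init__ wraps it in pathlib.Path.
    nlp2 = Language(path='/models/en_default', vocab=nlp.vocab)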

@@ -1,31 +1,108 @@
 from __future__ import absolute_import
-from os import path
 from warnings import warn
-import io
+import pathlib
 try:
     import ujson as json
 except ImportError:
     import json
 from .tokenizer import Tokenizer
 from .vocab import Vocab
 from .syntax.parser import Parser
 from .tagger import Tagger
 from .matcher import Matcher
-from .serialize.packer import Packer
 from . import attrs
 from . import orth
 from .syntax.ner import BiluoPushDown
 from .syntax.arc_eager import ArcEager
-from . import util
-from . import about
 from .attrs import TAG, DEP, ENT_IOB, ENT_TYPE, HEAD
-class Language(object):
-    lang = None
+class Defaults(object):
+    def __init__(self, lang, path):
+        self.lang = lang
+        self.path = path
+
+    def Vectors(self):
+        pass
+
+    def Vocab(self, vectors=None, get_lex_attr=None):
+        if get_lex_attr is None:
+            get_lex_attr = self.lex_attrs()
+        if vectors is None:
+            vectors = self.Vectors()
+        return Vocab.load(self.path, get_lex_attr=get_lex_attr, vectors=vectors)
+
+    def Tokenizer(self, vocab):
+        return Tokenizer.load(self.path, vocab)
+
+    def Tagger(self, vocab):
+        return Tagger.load(self.path, vocab)
+
+    def Parser(self, vocab):
+        if (self.path / 'deps').exists():
+            return Parser.load(self.path / 'deps', vocab, ArcEager)
+        else:
+            return None
+
+    def Entity(self, vocab):
+        if (self.path / 'ner').exists():
+            return Parser.load(self.path / 'ner', vocab, BiluoPushDown)
+        else:
+            return None
+
+    def Matcher(self, vocab):
+        return Matcher.load(self.path, vocab)
+
+    def Pipeline(self, nlp):
+        return [
+            nlp.tokenizer,
+            nlp.tagger,
+            nlp.parser,
+            nlp.entity]
+
+    def dep_labels(self):
+        return {0: {'ROOT': True}}
+
+    def ner_labels(self):
+        return {0: {'PER': True, 'LOC': True, 'ORG': True, 'MISC': True}}
+
+    def lex_attrs(self, *args, **kwargs):
+        if 'oov_prob' in kwargs:
+            oov_prob = kwargs['oov_prob']
+        else:
+            with (self.path / 'vocab' / 'oov_prob').open() as file_:
+                oov_prob = float(file_.read().strip())
+        return {
+            attrs.LOWER: self.lower,
+            attrs.NORM: self.norm,
+            attrs.SHAPE: orth.word_shape,
+            attrs.PREFIX: self.prefix,
+            attrs.SUFFIX: self.suffix,
+            attrs.CLUSTER: self.cluster,
+            attrs.PROB: lambda string: oov_prob,
+            attrs.LANG: lambda string: self.lang,
+            attrs.IS_ALPHA: orth.is_alpha,
+            attrs.IS_ASCII: orth.is_ascii,
+            attrs.IS_DIGIT: self.is_digit,
+            attrs.IS_LOWER: orth.is_lower,
+            attrs.IS_PUNCT: orth.is_punct,
+            attrs.IS_SPACE: self.is_space,
+            attrs.IS_TITLE: orth.is_title,
+            attrs.IS_UPPER: orth.is_upper,
+            attrs.IS_BRACKET: orth.is_bracket,
+            attrs.IS_QUOTE: orth.is_quote,
+            attrs.IS_LEFT_PUNCT: orth.is_left_punct,
+            attrs.IS_RIGHT_PUNCT: orth.is_right_punct,
+            attrs.LIKE_URL: orth.like_url,
+            attrs.LIKE_NUM: orth.like_number,
+            attrs.LIKE_EMAIL: orth.like_email,
+            attrs.IS_STOP: self.is_stop,
+            attrs.IS_OOV: lambda string: True
+        }
     @staticmethod
     def lower(string):
@@ -59,94 +136,27 @@ class Language(object):
     def is_stop(string):
         return 0

-    @classmethod
-    def default_lex_attrs(cls, *args, **kwargs):
-        oov_prob = kwargs.get('oov_prob', -20)
-        return {
-            attrs.LOWER: cls.lower,
-            attrs.NORM: cls.norm,
-            attrs.SHAPE: orth.word_shape,
-            attrs.PREFIX: cls.prefix,
-            attrs.SUFFIX: cls.suffix,
-            attrs.CLUSTER: cls.cluster,
-            attrs.PROB: lambda string: oov_prob,
-            attrs.LANG: lambda string: cls.lang,
-            attrs.IS_ALPHA: orth.is_alpha,
-            attrs.IS_ASCII: orth.is_ascii,
-            attrs.IS_DIGIT: cls.is_digit,
-            attrs.IS_LOWER: orth.is_lower,
-            attrs.IS_PUNCT: orth.is_punct,
-            attrs.IS_SPACE: cls.is_space,
-            attrs.IS_TITLE: orth.is_title,
-            attrs.IS_UPPER: orth.is_upper,
-            attrs.IS_BRACKET: orth.is_bracket,
-            attrs.IS_QUOTE: orth.is_quote,
-            attrs.IS_LEFT_PUNCT: orth.is_left_punct,
-            attrs.IS_RIGHT_PUNCT: orth.is_right_punct,
-            attrs.LIKE_URL: orth.like_url,
-            attrs.LIKE_NUM: orth.like_number,
-            attrs.LIKE_EMAIL: orth.like_email,
-            attrs.IS_STOP: cls.is_stop,
-            attrs.IS_OOV: lambda string: True
-        }
-
-    @classmethod
-    def default_dep_labels(cls):
-        return {0: {'ROOT': True}}
-
-    @classmethod
-    def default_ner_labels(cls):
-        return {0: {'PER': True, 'LOC': True, 'ORG': True, 'MISC': True}}
-
-    @classmethod
-    def default_vocab(cls, package, get_lex_attr=None, vectors_package=None):
-        if get_lex_attr is None:
-            if package.has_file('vocab', 'oov_prob'):
-                with package.open(('vocab', 'oov_prob')) as file_:
-                    oov_prob = float(file_.read().strip())
-                get_lex_attr = cls.default_lex_attrs(oov_prob=oov_prob)
-            else:
-                get_lex_attr = cls.default_lex_attrs()
-        if hasattr(package, 'dir_path'):
-            return Vocab.from_package(package, get_lex_attr=get_lex_attr,
-                                      vectors_package=vectors_package)
-        else:
-            return Vocab.load(package, get_lex_attr)
-
-    @classmethod
-    def default_parser(cls, package, vocab):
-        if hasattr(package, 'dir_path'):
-            data_dir = package.dir_path('deps')
-        else:
-            data_dir = package
-        if data_dir and path.exists(data_dir):
-            return Parser.from_dir(data_dir, vocab.strings, ArcEager)
-        else:
-            return None
-
-    @classmethod
-    def default_entity(cls, package, vocab):
-        if hasattr(package, 'dir_path'):
-            data_dir = package.dir_path('ner')
-        else:
-            data_dir = package
-        if data_dir and path.exists(data_dir):
-            return Parser.from_dir(data_dir, vocab.strings, BiluoPushDown)
-        else:
-            return None
+class Language(object):
+    '''A text-processing pipeline. Usually you'll load this once per process, and
+    pass the instance around your program.
+    '''
+    lang = None

     def __init__(self,
-                 data_dir=None,
-                 vocab=None,
-                 tokenizer=None,
-                 tagger=None,
-                 parser=None,
-                 entity=None,
-                 matcher=None,
-                 serializer=None,
-                 load_vectors=True,
-                 package=None,
-                 vectors_package=None):
+                 path=None,
+                 vocab=True,
+                 tokenizer=True,
+                 tagger=True,
+                 parser=True,
+                 entity=True,
+                 matcher=True,
+                 serializer=True,
+                 vectors=True,
+                 pipeline=True,
+                 defaults=True,
+                 data_dir=None):
         """
         A model can be specified:
@@ -165,44 +175,24 @@ class Language(object):
         - spacy.load('en_default', via='/my/package/root')
         - spacy.load('en_default==1.0.0', via='/my/package/root')
         """
-        if package is None:
-            if data_dir is None:
-                package = util.get_package_by_name(about.__models__[self.lang])
-            else:
-                package = util.get_package(data_dir)
-
-        if load_vectors is not True:
-            warn("load_vectors is deprecated", DeprecationWarning)
-
-        if vocab in (None, True):
-            vocab = self.default_vocab(package, vectors_package=vectors_package)
-        self.vocab = vocab
-        if tokenizer in (None, True):
-            tokenizer = Tokenizer.from_package(package, self.vocab)
-        self.tokenizer = tokenizer
-        if tagger in (None, True):
-            tagger = Tagger.from_package(package, self.vocab)
-        self.tagger = tagger
-        if entity in (None, True):
-            entity = self.default_entity(package, self.vocab)
-        self.entity = entity
-        if parser in (None, True):
-            parser = self.default_parser(package, self.vocab)
-        self.parser = parser
-        if matcher in (None, True):
-            matcher = Matcher.from_package(package, self.vocab)
-        self.matcher = matcher
-        self.pipeline = [
-            self.tokenizer,
-            self.tagger,
-            self.entity,
-            self.parser,
-            self.matcher
-        ]
+        if data_dir is not None and path is None:
+            warn("'data_dir' argument now named 'path'. Doing what you mean.")
+            path = data_dir
+        if isinstance(path, basestring):
+            path = pathlib.Path(path)
+        self.path = path
+        defaults = defaults if defaults is not True else self.get_defaults(self.path)
+        self.vocab = vocab if vocab is not True else defaults.Vocab(vectors=vectors)
+        self.tokenizer = tokenizer if tokenizer is not True else defaults.Tokenizer(self.vocab)
+        self.tagger = tagger if tagger is not True else defaults.Tagger(self.vocab)
+        self.entity = entity if entity is not True else defaults.Entity(self.vocab)
+        self.parser = parser if parser is not True else defaults.Parser(self.vocab)
+        self.matcher = matcher if matcher is not True else defaults.Matcher(self.vocab)
+        self.pipeline = pipeline if pipeline is not True else defaults.Pipeline(self)

     def __reduce__(self):
         args = (
-            None,  # data_dir
+            self.path,
             self.vocab,
             self.tokenizer,
             self.tagger,
@@ -255,23 +245,23 @@ class Language(object):
         for doc in stream:
             yield doc

-    def end_training(self, data_dir=None):
-        if data_dir is None:
-            data_dir = self.data_dir
+    def end_training(self, path=None):
+        if path is None:
+            path = self.path
         if self.parser:
             self.parser.model.end_training()
-            self.parser.model.dump(path.join(data_dir, 'deps', 'model'))
+            self.parser.model.dump(path / 'deps' / 'model')
         if self.entity:
             self.entity.model.end_training()
-            self.entity.model.dump(path.join(data_dir, 'ner', 'model'))
+            self.entity.model.dump(path / 'ner' / 'model')
         if self.tagger:
             self.tagger.model.end_training()
-            self.tagger.model.dump(path.join(data_dir, 'pos', 'model'))
-        strings_loc = path.join(data_dir, 'vocab', 'strings.json')
-        with io.open(strings_loc, 'w', encoding='utf8') as file_:
+            self.tagger.model.dump(path / 'pos' / 'model')
+        strings_loc = path / 'vocab' / 'strings.json'
+        with strings_loc.open('w', encoding='utf8') as file_:
             self.vocab.strings.dump(file_)
-        self.vocab.dump(path.join(data_dir, 'vocab', 'lexemes.bin'))
+        self.vocab.dump(path / 'vocab' / 'lexemes.bin')

         if self.tagger:
             tagger_freqs = list(self.tagger.freqs[TAG].items())
@@ -289,7 +279,7 @@ class Language(object):
         else:
             entity_iob_freqs = []
             entity_type_freqs = []
-        with open(path.join(data_dir, 'vocab', 'serializer.json'), 'w') as file_:
+        with (path / 'vocab' / 'serializer.json').open('w') as file_:
             file_.write(
                 json.dumps([
                     (TAG, tagger_freqs),
@@ -298,3 +288,7 @@ class Language(object):
                     (ENT_TYPE, entity_type_freqs),
                     (HEAD, head_freqs)
                 ]))
+
+    def get_defaults(self, path):
+        return Defaults(self.lang, path)
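
The get_defaults hook at the end is the extension point: a Language subclass can return its own Defaults subclass to change how components are built. A hypothetical sketch, assuming the Defaults signature from this diff (GermanDefaults and its label set are illustrative, not part of the commit):

    class GermanDefaults(Defaults):
        # Illustrative override: a language-specific dependency label set.
        def dep_labels(self):
            return {0: {'ROOT': True, 'sb': True, 'oa': True}}


    class German(Language):
        lang = 'de'

        def get_defaults(self, path):
            return GermanDefaults(self.lang, path)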