mirror of https://github.com/explosion/spaCy.git
Improve ud-train script. Make config optional
This commit is contained in:
parent 3e3a309764
commit 59cf533879
@@ -13,7 +13,7 @@ import spacy
 import spacy.util
 from ..tokens import Token, Doc
 from ..gold import GoldParse
-from ..util import compounding, minibatch_by_words
+from ..util import compounding, minibatch, minibatch_by_words
 from ..syntax.nonproj import projectivize
 from ..matcher import Matcher
 from .. import displacy
@@ -302,8 +302,8 @@ def initialize_pipeline(nlp, docs, golds, config, device):
 class Config(object):
     def __init__(self, vectors=None, max_doc_length=10, multitask_tag=True,
                  multitask_sent=True, multitask_dep=True, multitask_vectors=False,
-                 nr_epoch=30, batch_size=1000, dropout=0.2,
-                 conv_depth=4, subword_features=True):
+                 nr_epoch=30, min_batch_size=1, max_batch_size=16, batch_by_words=False,
+                 dropout=0.2, conv_depth=4, subword_features=True):
         for key, value in locals().items():
             setattr(self, key, value)
 
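For context, Config.__init__ turns every keyword argument into an instance attribute via locals(), which is why the new min_batch_size, max_batch_size and batch_by_words knobs become available as config.min_batch_size and so on with no further wiring. A minimal standalone sketch of that idiom (trimmed parameter list, not the exact ud_train code):

# Copy every keyword argument onto the instance. Note that locals()
# inside __init__ also contains "self", so the object ends up with a
# harmless self-referential "self" attribute as well.
class Config(object):
    def __init__(self, nr_epoch=30, min_batch_size=1, max_batch_size=16,
                 batch_by_words=False, dropout=0.2):
        for key, value in locals().items():
            setattr(self, key, value)

cfg = Config(max_batch_size=32)
assert cfg.max_batch_size == 32 and cfg.nr_epoch == 30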
@@ -346,20 +346,23 @@ class TreebankPaths(object):
     corpus=("UD corpus to train and evaluate on, e.g. en, es_ancora, etc",
             "positional", None, str),
     parses_dir=("Directory to write the development parses", "positional", None, Path),
-    config=("Path to json formatted config file", "positional"),
+    config=("Path to json formatted config file", "option", "C", Path),
     limit=("Size limit", "option", "n", int),
     use_gpu=("Use GPU", "option", "g", int),
     use_oracle_segments=("Use oracle segments", "flag", "G", int),
     vectors_dir=("Path to directory with pre-trained vectors, named e.g. en/",
                  "option", "v", Path),
 )
-def main(ud_dir, parses_dir, config, corpus, limit=0, use_gpu=-1, vectors_dir=None,
+def main(ud_dir, parses_dir, corpus, config=None, limit=0, use_gpu=-1, vectors_dir=None,
          use_oracle_segments=False):
     spacy.util.fix_random_seed()
     lang.zh.Chinese.Defaults.use_jieba = False
     lang.ja.Japanese.Defaults.use_janome = False
 
-    config = Config.load(config)
+    if config is not None:
+        config = Config.load(config)
+    else:
+        config = Config()
     paths = TreebankPaths(ud_dir, corpus)
     if not (parses_dir / corpus).exists():
         (parses_dir / corpus).mkdir()
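With the plac annotation changed from "positional" to ("option", "C", Path), the config file is now passed as -C/--config and may be omitted entirely, in which case main() falls back to the Config() defaults. A plausible sketch of the Config.load classmethod this branch calls, assuming it reads a flat JSON object of keyword arguments (the exact helper in ud_train.py may differ):

import json
from pathlib import Path

class Config(object):
    def __init__(self, **cfg):  # simplified; the real __init__ lists explicit defaults
        for key, value in cfg.items():
            setattr(self, key, value)

    @classmethod
    def load(cls, loc):
        # Parse the JSON file and hand its keys straight to __init__, so a
        # config file only needs to set the values it wants to override,
        # e.g. {"nr_epoch": 10, "batch_by_words": true}.
        with Path(loc).open("r", encoding="utf8") as file_:
            return cls(**json.load(file_))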
@@ -372,7 +375,7 @@ def main(ud_dir, parses_dir, config, corpus, limit=0, use_gpu=-1, vectors_dir=No
 
     optimizer = initialize_pipeline(nlp, docs, golds, config, use_gpu)
 
-    batch_sizes = compounding(config.batch_size//10, config.batch_size, 1.001)
+    batch_sizes = compounding(config.min_batch_size, config.max_batch_size, 1.001)
     beam_prob = compounding(0.2, 0.8, 1.001)
     for i in range(config.nr_epoch):
         docs, golds = read_data(nlp, paths.train.conllu.open(), paths.train.text.open(),
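The batch-size schedule now compounds from min_batch_size up to max_batch_size instead of from batch_size//10 to batch_size. spacy.util.compounding is a geometric schedule; a minimal sketch of its behaviour for the start < stop case (the real generator also handles decaying schedules):

def compounding(start, stop, compound):
    # Infinite generator: start, start*compound, start*compound**2, ...
    # clipped at stop once the schedule saturates.
    curr = float(start)
    while True:
        yield min(curr, stop)
        curr *= compound

sizes = compounding(1, 16, 1.001)
print([next(sizes) for _ in range(3)])  # [1.0, 1.001, 1.002001]
# With compound=1.001 the value reaches 16 after roughly 2,800 draws,
# so the batch size ramps up very gradually over training.

Each draw corresponds to one minibatch, since the batching helpers below pull one value from the generator per batch.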
@@ -381,7 +384,10 @@ def main(ud_dir, parses_dir, config, corpus, limit=0, use_gpu=-1, vectors_dir=No
                                 raw_text=not use_oracle_segments)
         Xs = list(zip(docs, golds))
         random.shuffle(Xs)
-        batches = minibatch_by_words(Xs, size=batch_sizes)
+        if config.batch_by_words:
+            batches = minibatch_by_words(Xs, size=batch_sizes)
+        else:
+            batches = minibatch(Xs, size=batch_sizes)
         losses = {}
         n_train_words = sum(len(doc) for doc in docs)
         with tqdm.tqdm(total=n_train_words, leave=False) as pbar:
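The new batch_by_words flag (defaulting to False in this commit) selects between batching by example count and batching by total token count. A rough, simplified sketch of the two spacy.util helpers as used here (the real implementations differ in details):

from itertools import islice, repeat

def minibatch(items, size):
    # Batch by example count. size may be an int or a generator such as
    # compounding(...), drawn once per batch.
    sizes = repeat(size) if isinstance(size, int) else size
    items = iter(items)
    while True:
        batch = list(islice(items, int(next(sizes))))
        if not batch:
            return
        yield batch

def minibatch_by_words(items, size):
    # Batch (doc, gold) pairs so each batch holds roughly size tokens,
    # which keeps per-batch compute even when document lengths vary a lot.
    sizes = repeat(size) if isinstance(size, int) else size
    items = iter(items)
    while True:
        budget = next(sizes)
        batch = []
        for doc, gold in items:
            batch.append((doc, gold))
            budget -= len(doc)
            if budget < 0:
                break
        if not batch:
            return
        yield batch

With plain minibatch the compounding schedule counts whole (doc, gold) examples, which matches the new min_batch_size=1 to max_batch_size=16 defaults; batching by words instead keeps the token count per batch roughly constant.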