Support GPU in UD training script

This commit is contained in:
Matthew Honnibal 2018-03-27 09:53:35 +00:00
parent dd54511c4f
commit 8bbd26579c
1 changed file with 8 additions and 9 deletions

View File

@@ -254,7 +254,7 @@ def load_nlp(corpus, config):
nlp.vocab.from_disk(Path(config.vectors) / 'vocab') nlp.vocab.from_disk(Path(config.vectors) / 'vocab')
return nlp return nlp
def initialize_pipeline(nlp, docs, golds, config): def initialize_pipeline(nlp, docs, golds, config, device):
nlp.add_pipe(nlp.create_pipe('parser')) nlp.add_pipe(nlp.create_pipe('parser'))
if config.multitask_tag: if config.multitask_tag:
nlp.parser.add_multitask_objective('tag') nlp.parser.add_multitask_objective('tag')
@@ -265,7 +265,7 @@ def initialize_pipeline(nlp, docs, golds, config):
for tag in gold.tags: for tag in gold.tags:
if tag is not None: if tag is not None:
nlp.tagger.add_label(tag) nlp.tagger.add_label(tag)
return nlp.begin_training(lambda: golds_to_gold_tuples(docs, golds)) return nlp.begin_training(lambda: golds_to_gold_tuples(docs, golds), device=device)
######################## ########################
@@ -318,15 +318,14 @@ class TreebankPaths(object):
"positional", None, str), "positional", None, str),
parses_dir=("Directory to write the development parses", "positional", None, Path), parses_dir=("Directory to write the development parses", "positional", None, Path),
config=("Path to json formatted config file", "positional"), config=("Path to json formatted config file", "positional"),
limit=("Size limit", "option", "n", int) limit=("Size limit", "option", "n", int),
use_gpu=("Use GPU", "option", "g", int)
) )
def main(ud_dir, parses_dir, config, corpus, limit=0): def main(ud_dir, parses_dir, config, corpus, limit=0, use_gpu=-1):
spacy.util.fix_random_seed()
lang.zh.Chinese.Defaults.use_jieba = False lang.zh.Chinese.Defaults.use_jieba = False
lang.ja.Japanese.Defaults.use_janome = False lang.ja.Japanese.Defaults.use_janome = False
random.seed(0)
numpy.random.seed(0)
config = Config.load(config) config = Config.load(config)
paths = TreebankPaths(ud_dir, corpus) paths = TreebankPaths(ud_dir, corpus)
if not (parses_dir / corpus).exists(): if not (parses_dir / corpus).exists():
@@ -337,7 +336,7 @@ def main(ud_dir, parses_dir, config, corpus, limit=0):
docs, golds = read_data(nlp, paths.train.conllu.open(), paths.train.text.open(), docs, golds = read_data(nlp, paths.train.conllu.open(), paths.train.text.open(),
max_doc_length=config.max_doc_length, limit=limit) max_doc_length=config.max_doc_length, limit=limit)
optimizer = initialize_pipeline(nlp, docs, golds, config) optimizer = initialize_pipeline(nlp, docs, golds, config, use_gpu)
batch_sizes = compounding(config.batch_size//10, config.batch_size, 1.001) batch_sizes = compounding(config.batch_size//10, config.batch_size, 1.001)
for i in range(config.nr_epoch): for i in range(config.nr_epoch):