mirror of https://github.com/explosion/spaCy.git
Support GPU in UD training script
This commit is contained in:
parent
dd54511c4f
commit
8bbd26579c
|
@ -254,7 +254,7 @@ def load_nlp(corpus, config):
|
||||||
nlp.vocab.from_disk(Path(config.vectors) / 'vocab')
|
nlp.vocab.from_disk(Path(config.vectors) / 'vocab')
|
||||||
return nlp
|
return nlp
|
||||||
|
|
||||||
def initialize_pipeline(nlp, docs, golds, config):
|
def initialize_pipeline(nlp, docs, golds, config, device):
|
||||||
nlp.add_pipe(nlp.create_pipe('parser'))
|
nlp.add_pipe(nlp.create_pipe('parser'))
|
||||||
if config.multitask_tag:
|
if config.multitask_tag:
|
||||||
nlp.parser.add_multitask_objective('tag')
|
nlp.parser.add_multitask_objective('tag')
|
||||||
|
@ -265,7 +265,7 @@ def initialize_pipeline(nlp, docs, golds, config):
|
||||||
for tag in gold.tags:
|
for tag in gold.tags:
|
||||||
if tag is not None:
|
if tag is not None:
|
||||||
nlp.tagger.add_label(tag)
|
nlp.tagger.add_label(tag)
|
||||||
return nlp.begin_training(lambda: golds_to_gold_tuples(docs, golds))
|
return nlp.begin_training(lambda: golds_to_gold_tuples(docs, golds), device=device)
|
||||||
|
|
||||||
|
|
||||||
########################
|
########################
|
||||||
|
@ -318,15 +318,14 @@ class TreebankPaths(object):
|
||||||
"positional", None, str),
|
"positional", None, str),
|
||||||
parses_dir=("Directory to write the development parses", "positional", None, Path),
|
parses_dir=("Directory to write the development parses", "positional", None, Path),
|
||||||
config=("Path to json formatted config file", "positional"),
|
config=("Path to json formatted config file", "positional"),
|
||||||
limit=("Size limit", "option", "n", int)
|
limit=("Size limit", "option", "n", int),
|
||||||
|
use_gpu=("Use GPU", "option", "g", int)
|
||||||
)
|
)
|
||||||
def main(ud_dir, parses_dir, config, corpus, limit=0):
|
def main(ud_dir, parses_dir, config, corpus, limit=0, use_gpu=-1):
|
||||||
|
spacy.util.fix_random_seed()
|
||||||
lang.zh.Chinese.Defaults.use_jieba = False
|
lang.zh.Chinese.Defaults.use_jieba = False
|
||||||
lang.ja.Japanese.Defaults.use_janome = False
|
lang.ja.Japanese.Defaults.use_janome = False
|
||||||
|
|
||||||
random.seed(0)
|
|
||||||
numpy.random.seed(0)
|
|
||||||
|
|
||||||
config = Config.load(config)
|
config = Config.load(config)
|
||||||
paths = TreebankPaths(ud_dir, corpus)
|
paths = TreebankPaths(ud_dir, corpus)
|
||||||
if not (parses_dir / corpus).exists():
|
if not (parses_dir / corpus).exists():
|
||||||
|
@ -337,7 +336,7 @@ def main(ud_dir, parses_dir, config, corpus, limit=0):
|
||||||
docs, golds = read_data(nlp, paths.train.conllu.open(), paths.train.text.open(),
|
docs, golds = read_data(nlp, paths.train.conllu.open(), paths.train.text.open(),
|
||||||
max_doc_length=config.max_doc_length, limit=limit)
|
max_doc_length=config.max_doc_length, limit=limit)
|
||||||
|
|
||||||
optimizer = initialize_pipeline(nlp, docs, golds, config)
|
optimizer = initialize_pipeline(nlp, docs, golds, config, use_gpu)
|
||||||
|
|
||||||
batch_sizes = compounding(config.batch_size//10, config.batch_size, 1.001)
|
batch_sizes = compounding(config.batch_size//10, config.batch_size, 1.001)
|
||||||
for i in range(config.nr_epoch):
|
for i in range(config.nr_epoch):
|
||||||
|
|
Loading…
Reference in New Issue