diff --git a/spacy/__main__.py b/spacy/__main__.py index cf959def7..acce3b7c8 100644 --- a/spacy/__main__.py +++ b/spacy/__main__.py @@ -83,18 +83,20 @@ class CLI(object): n_iter=("number of iterations", "option", "n", int), nsents=("number of sentences", "option", None, int), parser_L1=("L1 regularization penalty for parser", "option", "L", float), + use_gpu=("Use GPU", "flag", "g", bool), no_tagger=("Don't train tagger", "flag", "T", bool), no_parser=("Don't train parser", "flag", "P", bool), no_ner=("Don't train NER", "flag", "N", bool) ) def train(self, lang, output_dir, train_data, dev_data=None, n_iter=15, - nsents=0, parser_L1=0.0, no_tagger=False, no_parser=False, no_ner=False): + nsents=0, parser_L1=0.0, use_gpu=False, + no_tagger=False, no_parser=False, no_ner=False): """ Train a model. Expects data in spaCy's JSON format. """ nsents = nsents or None - cli_train(lang, output_dir, train_data, dev_data, n_iter, nsents, not no_tagger, - not no_parser, not no_ner, parser_L1) + cli_train(lang, output_dir, train_data, dev_data, n_iter, nsents, + use_gpu, not no_tagger, not no_parser, not no_ner, parser_L1) @plac.annotations( lang=("model language", "positional", None, str),