diff --git a/spacy/_ml.py b/spacy/_ml.py index 4dbc7cb92..eb751ca6c 100644 --- a/spacy/_ml.py +++ b/spacy/_ml.py @@ -29,6 +29,7 @@ from . import util try: import torch.nn + from thinc.extra.wrappers import PyTorchWrapperRNN except: torch = None @@ -252,7 +253,7 @@ def link_vectors_to_models(vocab): def PyTorchBiLSTM(nO, nI, depth, dropout=0.2): if depth == 0: - return noop() + return layerize(noop()) model = torch.nn.LSTM(nI, nO//2, depth, bidirectional=True, dropout=dropout) return with_square_sequences(PyTorchWrapperRNN(model)) @@ -299,7 +300,6 @@ def Tok2Vec(width, embed_size, **kwargs): ExtractWindow(nW=1) >> LN(Maxout(width, width*3, pieces=cnn_maxout_pieces)) ) - tok2vec = ( FeatureExtracter(cols) >> with_flatten( diff --git a/spacy/cli/ud_train.py b/spacy/cli/ud_train.py index 8244b1b7f..4c0b3c7eb 100644 --- a/spacy/cli/ud_train.py +++ b/spacy/cli/ud_train.py @@ -32,6 +32,11 @@ from .. import lang from ..lang import zh from ..lang import ja +try: + import torch +except ImportError: + torch = None + ################ # Data reading # @@ -207,6 +212,14 @@ def write_conllu(docs, file_): file_.write("# sent_id = {i}.{j}\n".format(i=i, j=j)) file_.write("# text = {text}\n".format(text=sent.text)) for k, token in enumerate(sent): + if token.head.i > sent[-1].i or token.head.i < sent[0].i: + for word in doc[sent[0].i-10 : sent[0].i]: + print(word.i, word.head.i, word.text, word.dep_) + for word in sent: + print(word.i, word.head.i, word.text, word.dep_) + for word in doc[sent[-1].i : sent[-1].i+10]: + print(word.i, word.head.i, word.text, word.dep_) + raise ValueError("Invalid parse: head outside sentence (%s)" % token.text) file_.write(token._.get_conllu_lines(k) + '\n') file_.write('\n') @@ -290,9 +303,12 @@ def initialize_pipeline(nlp, docs, golds, config, device): for tag in gold.tags: if tag is not None: nlp.tagger.add_label(tag) + if torch is not None and device != -1: + torch.set_default_tensor_type('torch.cuda.FloatTensor') return nlp.begin_training( lambda: golds_to_gold_tuples(docs, golds), device=device, - subword_features=config.subword_features, conv_depth=config.conv_depth) + subword_features=config.subword_features, conv_depth=config.conv_depth, + bilstm_depth=config.bilstm_depth) ######################## @@ -356,12 +372,12 @@ class TreebankPaths(object): parses_dir=("Directory to write the development parses", "positional", None, Path), config=("Path to json formatted config file", "option", "C", Path), limit=("Size limit", "option", "n", int), - use_gpu=("Use GPU", "option", "g", int), + gpu_device=("Use GPU", "option", "g", int), use_oracle_segments=("Use oracle segments", "flag", "G", int), vectors_dir=("Path to directory with pre-trained vectors, named e.g. en/", "option", "v", Path), ) -def main(ud_dir, parses_dir, corpus, config=None, limit=0, use_gpu=-1, vectors_dir=None, +def main(ud_dir, parses_dir, corpus, config=None, limit=0, gpu_device=-1, vectors_dir=None, use_oracle_segments=False): spacy.util.fix_random_seed() lang.zh.Chinese.Defaults.use_jieba = False @@ -381,7 +397,7 @@ def main(ud_dir, parses_dir, corpus, config=None, limit=0, use_gpu=-1, vectors_d max_doc_length=config.max_doc_length, limit=limit) - optimizer = initialize_pipeline(nlp, docs, golds, config, use_gpu) + optimizer = initialize_pipeline(nlp, docs, golds, config, gpu_device) batch_sizes = compounding(config.min_batch_size, config.max_batch_size, 1.001) beam_prob = compounding(0.2, 0.8, 1.001) @@ -415,7 +431,6 @@ def main(ud_dir, parses_dir, corpus, config=None, limit=0, use_gpu=-1, vectors_d parsed_docs, scores = evaluate(nlp, paths.dev.text, paths.dev.conllu, out_path) print_progress(i, losses, scores) - _render_parses(i, parsed_docs[:50]) def _render_parses(i, to_render): diff --git a/spacy/language.py b/spacy/language.py index a993f7eb3..e64768d05 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -426,6 +426,10 @@ class Language(object): def get_grads(W, dW, key=None): grads[key] = (W, dW) + get_grads.alpha = sgd.alpha + get_grads.b1 = sgd.b1 + get_grads.b2 = sgd.b2 + pipes = list(self.pipeline) random.shuffle(pipes) for name, proc in pipes: