Improve train tensorizer script

This commit is contained in:
Matthew Honnibal 2018-11-03 10:54:20 +00:00
parent ba365ae1c9
commit 0127f10ba3
1 changed file with 14 additions and 7 deletions

View File

@ -2,7 +2,7 @@
import plac import plac
import spacy import spacy
import thinc.extra.datasets import thinc.extra.datasets
from spacy.util import minibatch from spacy.util import minibatch, use_gpu
import tqdm import tqdm
@ -12,7 +12,7 @@ def load_imdb():
train_texts, _ = zip(*train) train_texts, _ = zip(*train)
dev_texts, _ = zip(*dev) dev_texts, _ = zip(*dev)
nlp.add_pipe(nlp.create_pipe('sentencizer')) nlp.add_pipe(nlp.create_pipe('sentencizer'))
return list(get_sentences(nlp, train_texts)), list(get_sentences(nlp, dev_texts)) return list(train_texts), list(dev_texts)
def get_sentences(nlp, texts): def get_sentences(nlp, texts):
@ -21,12 +21,20 @@ def get_sentences(nlp, texts):
yield sent.text yield sent.text
def prefer_gpu():
    """Request GPU device 0 and report whether the request succeeded.

    Calls ``spacy.util.use_gpu(0)`` and converts its result to a boolean.

    Returns:
        bool: ``False`` when ``use_gpu`` returned ``None``, ``True`` otherwise.
    """
    # NOTE(review): presumably use_gpu() returns None when no GPU backend is
    # available -- confirm against the installed spaCy version.
    return spacy.util.use_gpu(0) is not None
def main(vectors_model):
use_gpu = prefer_gpu()
print("Using GPU?", use_gpu)
print("Load data") print("Load data")
train_texts, dev_texts = load_imdb() train_texts, dev_texts = load_imdb()
train_texts = train_texts[:1000]
print("Load vectors") print("Load vectors")
nlp = spacy.load('en_vectors_web_lg') nlp = spacy.load(vectors_model)
print("Start training") print("Start training")
nlp.add_pipe(nlp.create_pipe('tagger')) nlp.add_pipe(nlp.create_pipe('tagger'))
tensorizer = nlp.create_pipe('tensorizer') tensorizer = nlp.create_pipe('tensorizer')
@ -38,8 +46,7 @@ def main():
for i, batch in enumerate(minibatch(tqdm.tqdm(train_texts))): for i, batch in enumerate(minibatch(tqdm.tqdm(train_texts))):
docs = [nlp.make_doc(text) for text in batch] docs = [nlp.make_doc(text) for text in batch]
tensorizer.update(docs, None, losses=losses, sgd=optimizer, drop=0.5) tensorizer.update(docs, None, losses=losses, sgd=optimizer, drop=0.5)
if i % 10 == 0: print(losses)
print(losses)
if __name__ == '__main__': if __name__ == '__main__':
plac.call(main) plac.call(main)