From d17546681ce01c3b7ce6e36a3271b149a464774b Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Thu, 20 Oct 2016 03:21:56 +0200 Subject: [PATCH] Fix deep learning tutorial --- examples/deep_learning_keras.py | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/examples/deep_learning_keras.py b/examples/deep_learning_keras.py index fc672da75..eb247561e 100644 --- a/examples/deep_learning_keras.py +++ b/examples/deep_learning_keras.py @@ -4,9 +4,8 @@ import random import cytoolz import numpy -from keras.layers import Sequential, LSTM, Dense, Embedding, Dropout -from keras.wrappers import Bidirectional -from keras import model_from_json +from keras.models import Sequential, model_from_json +from keras.layers import LSTM, Dense, Embedding, Dropout, Bidirectional import cPickle as pickle import spacy @@ -127,17 +126,16 @@ def read_data(data_dir, limit=0): @plac.annotations( - language=("The language to train", "positional", None, str, ['en','de', 'zh']), - train_loc=("Location of training file or directory"), - dev_loc=("Location of development file or directory"), + train_dir=("Location of training file or directory"), + dev_dir=("Location of development file or directory"), model_dir=("Location of output model directory",), is_runtime=("Demonstrate run-time usage", "flag", "r", bool), - nr_hidden=("Number of hidden units", "flag", "H", int), - max_length=("Maximum sentence length", "flag", "L", int), - dropout=("Dropout", "flag", "d", float), - nr_epoch=("Number of training epochs", "flag", "i", int), - batch_size=("Size of minibatches for training LSTM", "flag", "b", int), - nr_examples=("Limit to N examples", "flag", "n", int) + nr_hidden=("Number of hidden units", "option", "H", int), + max_length=("Maximum sentence length", "option", "L", int), + dropout=("Dropout", "option", "d", float), + nb_epoch=("Number of training epochs", "option", "i", int), + batch_size=("Size of minibatches for training LSTM", "option", "b", int), + nr_examples=("Limit to N examples", "option", "n", int) ) def main(model_dir, train_dir, dev_dir, is_runtime=False,