mirror of https://github.com/explosion/spaCy.git
Remove unused model_dir option
As noted in #845, the `model_dir` argument was not being used. I've removed it for now, although it would be good to have this option restored and working.
This commit is contained in:
parent
724e51ed47
commit
c031c677cc
|
@ -18,7 +18,7 @@ except ImportError:
|
||||||
import pickle
|
import pickle
|
||||||
|
|
||||||
|
|
||||||
def train(model_dir, train_loc, dev_loc, shape, settings):
|
def train(train_loc, dev_loc, shape, settings):
|
||||||
train_texts1, train_texts2, train_labels = read_snli(train_loc)
|
train_texts1, train_texts2, train_labels = read_snli(train_loc)
|
||||||
dev_texts1, dev_texts2, dev_labels = read_snli(dev_loc)
|
dev_texts1, dev_texts2, dev_labels = read_snli(dev_loc)
|
||||||
|
|
||||||
|
@ -44,7 +44,7 @@ def train(model_dir, train_loc, dev_loc, shape, settings):
|
||||||
batch_size=settings['batch_size'])
|
batch_size=settings['batch_size'])
|
||||||
if not (nlp.path / 'similarity').exists():
|
if not (nlp.path / 'similarity').exists():
|
||||||
(nlp.path / 'similarity').mkdir()
|
(nlp.path / 'similarity').mkdir()
|
||||||
print("Saving to", model_dir / 'similarity')
|
print("Saving to", nlp.path / 'similarity')
|
||||||
weights = model.get_weights()
|
weights = model.get_weights()
|
||||||
with (nlp.path / 'similarity' / 'model').open('wb') as file_:
|
with (nlp.path / 'similarity' / 'model').open('wb') as file_:
|
||||||
pickle.dump(weights[1:], file_)
|
pickle.dump(weights[1:], file_)
|
||||||
|
@ -68,8 +68,8 @@ def evaluate(model_dir, dev_loc):
|
||||||
return correct, total
|
return correct, total
|
||||||
|
|
||||||
|
|
||||||
def demo(model_dir):
|
def demo():
|
||||||
nlp = spacy.load('en', path=model_dir,
|
nlp = spacy.load('en',
|
||||||
create_pipeline=create_similarity_pipeline)
|
create_pipeline=create_similarity_pipeline)
|
||||||
doc1 = nlp(u'What were the best crime fiction books in 2016?')
|
doc1 = nlp(u'What were the best crime fiction books in 2016?')
|
||||||
doc2 = nlp(
|
doc2 = nlp(
|
||||||
|
@ -98,7 +98,6 @@ def read_snli(path):
|
||||||
|
|
||||||
@plac.annotations(
|
@plac.annotations(
|
||||||
mode=("Mode to execute", "positional", None, str, ["train", "evaluate", "demo"]),
|
mode=("Mode to execute", "positional", None, str, ["train", "evaluate", "demo"]),
|
||||||
model_dir=("Path to spaCy model directory", "positional", None, Path),
|
|
||||||
train_loc=("Path to training data", "positional", None, Path),
|
train_loc=("Path to training data", "positional", None, Path),
|
||||||
dev_loc=("Path to development data", "positional", None, Path),
|
dev_loc=("Path to development data", "positional", None, Path),
|
||||||
max_length=("Length to truncate sentences", "option", "L", int),
|
max_length=("Length to truncate sentences", "option", "L", int),
|
||||||
|
@ -110,7 +109,7 @@ def read_snli(path):
|
||||||
tree_truncate=("Truncate sentences by tree distance", "flag", "T", bool),
|
tree_truncate=("Truncate sentences by tree distance", "flag", "T", bool),
|
||||||
gru_encode=("Encode sentences with bidirectional GRU", "flag", "E", bool),
|
gru_encode=("Encode sentences with bidirectional GRU", "flag", "E", bool),
|
||||||
)
|
)
|
||||||
def main(mode, model_dir, train_loc, dev_loc,
|
def main(mode, train_loc, dev_loc,
|
||||||
tree_truncate=False,
|
tree_truncate=False,
|
||||||
gru_encode=False,
|
gru_encode=False,
|
||||||
max_length=100,
|
max_length=100,
|
||||||
|
@ -129,12 +128,12 @@ def main(mode, model_dir, train_loc, dev_loc,
|
||||||
'gru_encode': gru_encode
|
'gru_encode': gru_encode
|
||||||
}
|
}
|
||||||
if mode == 'train':
|
if mode == 'train':
|
||||||
train(model_dir, train_loc, dev_loc, shape, settings)
|
train(train_loc, dev_loc, shape, settings)
|
||||||
elif mode == 'evaluate':
|
elif mode == 'evaluate':
|
||||||
correct, total = evaluate(model_dir, dev_loc)
|
correct, total = evaluate(dev_loc)
|
||||||
print(correct, '/', total, correct / total)
|
print(correct, '/', total, correct / total)
|
||||||
else:
|
else:
|
||||||
demo(model_dir)
|
demo()
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
plac.call(main)
|
plac.call(main)
|
||||||
|
|
Loading…
Reference in New Issue