Add document length cap for training

Matthew Honnibal 2017-11-03 01:54:54 +01:00
parent 6771780d3f
commit c2bbf076a4
1 changed file with 4 additions and 0 deletions


@@ -85,6 +85,7 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=30, n_sents=0,
     batch_sizes = util.compounding(util.env_opt('batch_from', 1),
                                    util.env_opt('batch_to', 16),
                                    util.env_opt('batch_compound', 1.001))
+    max_doc_len = util.env_opt('max_doc_len', 5000)
     corpus = GoldCorpus(train_path, dev_path, limit=n_sents)
     n_train_words = corpus.count_train()
@@ -108,6 +109,9 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=30, n_sents=0,
     with tqdm.tqdm(total=n_train_words, leave=False) as pbar:
         losses = {}
         for batch in minibatch(train_docs, size=batch_sizes):
+            batch = [(d, g) for (d, g) in batch if len(d) < max_doc_len]
+            if not batch:
+                continue
             docs, golds = zip(*batch)
             nlp.update(docs, golds, sgd=optimizer,
                        drop=next(dropout_rates), losses=losses)
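
For context, the cap filters (doc, gold) pairs out of each minibatch rather than truncating documents, and skips a batch entirely if nothing survives the filter. The snippet below is a minimal standalone sketch of that behaviour, not spaCy's actual training loop: the function name, toy data, and placeholder gold objects are hypothetical, and the real code reads max_doc_len via util.env_opt as shown in the diff.

    # Standalone sketch of the length cap added in this commit.
    # "Docs" only need to support len() (a token count), mirroring
    # the `len(d) < max_doc_len` check in the training loop.

    def filter_long_docs(batches, max_doc_len=5000):
        """Yield minibatches of (doc, gold) pairs, dropping overlong docs."""
        for batch in batches:
            batch = [(d, g) for (d, g) in batch if len(d) < max_doc_len]
            if not batch:
                continue  # every doc in this batch exceeded the cap
            yield batch


    if __name__ == '__main__':
        # Toy data: docs are plain token lists, golds are placeholder dicts.
        docs = [['a'] * 10, ['b'] * 6000, ['c'] * 20]
        golds = [{}, {}, {}]
        batches = [list(zip(docs, golds))]
        for kept in filter_long_docs(batches, max_doc_len=5000):
            print([len(d) for d, _ in kept])  # -> [10, 20]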