fix bug in iterator grouping

mehrad 2020-04-19 18:03:11 -07:00
parent 0c68277dfe
commit 36b0fb5317
4 changed files with 13 additions and 14 deletions

.gitignore

@@ -24,6 +24,7 @@ workdir/
 *save*/
 tests/dataset-*
 tests/dataset/*
+tests/test_py.sh
 # C extensions

@@ -108,10 +108,7 @@ class AlmondDataset(CQA):
     def is_entity(token):
-        try:
-            return token[0].isupper()
-        except:
-            print('here')
+        return token[0].isupper()

     def process_id(ex):
         id_ = ex.example_id.rsplit('/', 1)
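For reference, the cleaned-up helper as a standalone sketch: the removed try/except was dead code (both paths evaluated the same expression, and the bare except only swallowed errors around a stray debug print). The sample tokens below are illustrative placeholders, not values from the repo.

# Sketch of the fixed helper; sample tokens are illustrative placeholders.
def is_entity(token):
    # Entity placeholders are upper-cased tokens, e.g. 'QUOTED_STRING_0'.
    return token[0].isupper()

assert is_entity('QUOTED_STRING_0')   # placeholder token -> treated as entity
assert not is_entity('weather')       # ordinary word -> not an entity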
@@ -377,12 +374,14 @@ class AlmondMultiLingual(BaseAlmondTask):
         sort_key_fn = context_answer_len
         batch_size_fn = token_batch_fn
+        groups = len(all_datasets) if kwargs.get('sentence_batching') else None

         if kwargs.get('separate_eval') and (all_datasets[0].eval or all_datasets[0].test):
             return all_datasets
         else:
-            return self.combine_datasets(all_datasets, sort_key_fn, batch_size_fn, used_fields)
+            return self.combine_datasets(all_datasets, sort_key_fn, batch_size_fn, used_fields, groups)

-    def combine_datasets(self, datasets, sort_key_fn, batch_size_fn, used_fields):
+    def combine_datasets(self, datasets, sort_key_fn, batch_size_fn, used_fields, groups):
         splits = defaultdict()
         for field in used_fields:
@@ -390,7 +389,7 @@ class AlmondMultiLingual(BaseAlmondTask):
             for dataset in datasets:
                 all_examples.extend(getattr(dataset, field).examples)
-            splits[field] = CQA(all_examples, sort_key_fn=sort_key_fn, batch_size_fn=batch_size_fn, groups=len(datasets))
+            splits[field] = CQA(all_examples, sort_key_fn=sort_key_fn, batch_size_fn=batch_size_fn, groups=groups)
         return Split(train=splits.get('train'),
                      eval=splits.get('eval'),
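This pair of hunks is the heart of the fix: combine_datasets used to hard-code groups=len(datasets), so the iterator grouped examples even when sentence batching was disabled; now the caller computes groups once (None unless sentence_batching is set) and threads it through. Below is a runnable sketch of the corrected flow, with a stubbed CQA standing in for the real class:

class CQA:
    # Stub of the real CQA dataset, reduced to the fields this sketch needs.
    def __init__(self, examples, sort_key_fn=None, batch_size_fn=None, groups=None):
        self.examples = examples
        self.groups = groups  # None tells the iterator not to group examples

def combine_datasets(datasets, groups):
    # Mirrors the fixed method: `groups` is passed in, never recomputed here.
    all_examples = []
    for dataset in datasets:
        all_examples.extend(dataset.examples)
    return CQA(all_examples, groups=groups)

datasets = [CQA(['ex1', 'ex2']), CQA(['ex3'])]
sentence_batching = False                              # grouping disabled
groups = len(datasets) if sentence_batching else None  # new caller-side logic
combined = combine_datasets(datasets, groups)
assert combined.groups is None   # the old code would have forced groups=2 here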

@@ -246,12 +246,11 @@ def make_data_loader(dataset, numericalizer, batch_size, device=None, paired=Fal
                            repeat=train,
                            use_data_batch_fn=train,
                            use_data_sort_key=train)
-    return torch.utils.data.DataLoader(iterator,
-                                       batch_size=None,
-                                       collate_fn=lambda minibatch: Batch.from_examples(minibatch, numericalizer,
-                                                                                        device=device, paired=paired and train,
-                                                                                        max_pairs=max_pairs, groups=iterator.groups))
+    collate_function = lambda minibatch: Batch.from_examples(minibatch, numericalizer, device=device,
+                                                             paired=paired and train, max_pairs=max_pairs, groups=iterator.groups)
+    return torch.utils.data.DataLoader(iterator, batch_size=None, collate_fn=collate_function)

 def pad(x, new_channel, dim, val=None):
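The loader change is a pure refactor: the inline collate lambda gets a name before being handed to the DataLoader, which keeps the call site to one line. A self-contained sketch of the same pattern follows; the toy minibatches and tensor collate are illustrative, not the real Batch.from_examples:

import torch

# Toy pre-batched data standing in for the real batching iterator.
minibatches = [[1, 2], [3, 4]]

# Name the collate function first, then pass it in: the shape of the refactor.
collate_function = lambda minibatch: torch.tensor(minibatch, dtype=torch.long)

# batch_size=None disables automatic batching; each element is already a batch,
# so collate_fn is applied to one minibatch at a time.
loader = torch.utils.data.DataLoader(minibatches, batch_size=None, collate_fn=collate_function)
for batch in loader:
    print(batch.shape)   # torch.Size([2]) for each toy minibatch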

@@ -39,7 +39,7 @@ for hparams in \
 do
 	# train
-	pipenv run python3 -m genienlp train --train_tasks almond --train_iterations 6 --preserve_case --save_every 2 --log_every 2 --val_every 2 --save $workdir/model_$i --data $SRCDIR/dataset/ $hparams --exist_ok --skip_cache --root "" --embeddings $embedding_dir --no_commit
+	pipenv run python3 -m genienlp train --train_tasks almond --train_iterations 6 --preserve_case --save_every 2 --log_every 2 --val_every 2 --save $workdir/model_$i --data $SRCDIR/dataset/ $hparams --exist_ok --skip_cache --embeddings $embedding_dir --no_commit

 	# greedy decode
 	pipenv run python3 -m genienlp predict --tasks almond --evaluate test --path $workdir/model_$i --overwrite --eval_dir $workdir/model_$i/eval_results/ --data $SRCDIR/dataset/ --embeddings $embedding_dir
@@ -63,7 +63,7 @@ for hparams in \
 do
 	# train
-	pipenv run python3 -m genienlp train --train_tasks almond_multilingual --train_languages fa+en --eval_languages fa+en --train_iterations 6 --preserve_case --save_every 2 --log_every 2 --val_every 2 --save $workdir/model_$i --data $SRCDIR/dataset/ $hparams --exist_ok --skip_cache --root "" --embeddings $embedding_dir --no_commit
+	pipenv run python3 -m genienlp train --train_tasks almond_multilingual --train_languages fa+en --eval_languages fa+en --train_iterations 6 --preserve_case --save_every 2 --log_every 2 --val_every 2 --save $workdir/model_$i --data $SRCDIR/dataset/ $hparams --exist_ok --skip_cache --embeddings $embedding_dir --no_commit

 	# greedy decode
 	# combined evaluation