Fix bug in iterator grouping: pass `groups` to the combined dataset only when sentence batching is enabled
This commit is contained in:
parent
0c68277dfe
commit
36b0fb5317
|
@ -24,6 +24,7 @@ workdir/
|
|||
*save*/
|
||||
tests/dataset-*
|
||||
tests/dataset/*
|
||||
tests/test_py.sh
|
||||
|
||||
|
||||
# C extensions
|
||||
|
|
|
@ -108,10 +108,7 @@ class AlmondDataset(CQA):
|
|||
|
||||
|
||||
def is_entity(token):
|
||||
try:
|
||||
return token[0].isupper()
|
||||
except:
|
||||
print('here')
|
||||
return token[0].isupper()
|
||||
|
||||
def process_id(ex):
|
||||
id_ = ex.example_id.rsplit('/', 1)
|
||||
|
@ -377,12 +374,14 @@ class AlmondMultiLingual(BaseAlmondTask):
|
|||
sort_key_fn = context_answer_len
|
||||
batch_size_fn = token_batch_fn
|
||||
|
||||
groups = len(all_datasets) if kwargs.get('sentence_batching') else None
|
||||
|
||||
if kwargs.get('separate_eval') and (all_datasets[0].eval or all_datasets[0].test):
|
||||
return all_datasets
|
||||
else:
|
||||
return self.combine_datasets(all_datasets, sort_key_fn, batch_size_fn, used_fields)
|
||||
return self.combine_datasets(all_datasets, sort_key_fn, batch_size_fn, used_fields, groups)
|
||||
|
||||
def combine_datasets(self, datasets, sort_key_fn, batch_size_fn, used_fields):
|
||||
def combine_datasets(self, datasets, sort_key_fn, batch_size_fn, used_fields, groups):
|
||||
splits = defaultdict()
|
||||
|
||||
for field in used_fields:
|
||||
|
@ -390,7 +389,7 @@ class AlmondMultiLingual(BaseAlmondTask):
|
|||
for dataset in datasets:
|
||||
all_examples.extend(getattr(dataset, field).examples)
|
||||
|
||||
splits[field] = CQA(all_examples, sort_key_fn=sort_key_fn, batch_size_fn=batch_size_fn, groups=len(datasets))
|
||||
splits[field] = CQA(all_examples, sort_key_fn=sort_key_fn, batch_size_fn=batch_size_fn, groups=groups)
|
||||
|
||||
return Split(train=splits.get('train'),
|
||||
eval=splits.get('eval'),
|
||||
|
|
|
@ -246,12 +246,11 @@ def make_data_loader(dataset, numericalizer, batch_size, device=None, paired=Fal
|
|||
repeat=train,
|
||||
use_data_batch_fn=train,
|
||||
use_data_sort_key=train)
|
||||
|
||||
collate_function = lambda minibatch: Batch.from_examples(minibatch, numericalizer, device=device,
|
||||
paired=paired and train, max_pairs=max_pairs, groups=iterator.groups)
|
||||
|
||||
return torch.utils.data.DataLoader(iterator,
|
||||
batch_size=None,
|
||||
collate_fn=lambda minibatch: Batch.from_examples(minibatch, numericalizer,
|
||||
device=device, paired=paired and train,
|
||||
max_pairs=max_pairs, groups=iterator.groups))
|
||||
return torch.utils.data.DataLoader(iterator, batch_size=None, collate_fn=collate_function)
|
||||
|
||||
|
||||
def pad(x, new_channel, dim, val=None):
|
||||
|
|
|
@ -39,7 +39,7 @@ for hparams in \
|
|||
do
|
||||
|
||||
# train
|
||||
pipenv run python3 -m genienlp train --train_tasks almond --train_iterations 6 --preserve_case --save_every 2 --log_every 2 --val_every 2 --save $workdir/model_$i --data $SRCDIR/dataset/ $hparams --exist_ok --skip_cache --root "" --embeddings $embedding_dir --no_commit
|
||||
pipenv run python3 -m genienlp train --train_tasks almond --train_iterations 6 --preserve_case --save_every 2 --log_every 2 --val_every 2 --save $workdir/model_$i --data $SRCDIR/dataset/ $hparams --exist_ok --skip_cache --embeddings $embedding_dir --no_commit
|
||||
|
||||
# greedy decode
|
||||
pipenv run python3 -m genienlp predict --tasks almond --evaluate test --path $workdir/model_$i --overwrite --eval_dir $workdir/model_$i/eval_results/ --data $SRCDIR/dataset/ --embeddings $embedding_dir
|
||||
|
@ -63,7 +63,7 @@ for hparams in \
|
|||
do
|
||||
|
||||
# train
|
||||
pipenv run python3 -m genienlp train --train_tasks almond_multilingual --train_languages fa+en --eval_languages fa+en --train_iterations 6 --preserve_case --save_every 2 --log_every 2 --val_every 2 --save $workdir/model_$i --data $SRCDIR/dataset/ $hparams --exist_ok --skip_cache --root "" --embeddings $embedding_dir --no_commit
|
||||
pipenv run python3 -m genienlp train --train_tasks almond_multilingual --train_languages fa+en --eval_languages fa+en --train_iterations 6 --preserve_case --save_every 2 --log_every 2 --val_every 2 --save $workdir/model_$i --data $SRCDIR/dataset/ $hparams --exist_ok --skip_cache --embeddings $embedding_dir --no_commit
|
||||
|
||||
# greedy decode
|
||||
# combined evaluation
|
||||
|
|
Loading…
Reference in New Issue