diff --git a/.gitignore b/.gitignore
index ead4f562..2d61e75f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -24,6 +24,7 @@ workdir/
 *save*/
 tests/dataset-*
 tests/dataset/*
+tests/test_py.sh
 
 # C extensions
 *.so
diff --git a/genienlp/tasks/almond/__init__.py b/genienlp/tasks/almond/__init__.py
index baa74a6d..49d2ef85 100644
--- a/genienlp/tasks/almond/__init__.py
+++ b/genienlp/tasks/almond/__init__.py
@@ -108,10 +108,7 @@ class AlmondDataset(CQA):
 
 
 def is_entity(token):
-    try:
-        return token[0].isupper()
-    except:
-        print('here')
+    return token[0].isupper()
 
 def process_id(ex):
     id_ = ex.example_id.rsplit('/', 1)
@@ -377,12 +374,14 @@ class AlmondMultiLingual(BaseAlmondTask):
             sort_key_fn = context_answer_len
             batch_size_fn = token_batch_fn
 
+        groups = len(all_datasets) if kwargs.get('sentence_batching') else None
+
         if kwargs.get('separate_eval') and (all_datasets[0].eval or all_datasets[0].test):
             return all_datasets
         else:
-            return self.combine_datasets(all_datasets, sort_key_fn, batch_size_fn, used_fields)
+            return self.combine_datasets(all_datasets, sort_key_fn, batch_size_fn, used_fields, groups)
 
-    def combine_datasets(self, datasets, sort_key_fn, batch_size_fn, used_fields):
+    def combine_datasets(self, datasets, sort_key_fn, batch_size_fn, used_fields, groups):
         splits = defaultdict()
 
         for field in used_fields:
@@ -390,7 +389,7 @@ class AlmondMultiLingual(BaseAlmondTask):
             for dataset in datasets:
                 all_examples.extend(getattr(dataset, field).examples)
-            splits[field] = CQA(all_examples, sort_key_fn=sort_key_fn, batch_size_fn=batch_size_fn, groups=len(datasets))
+            splits[field] = CQA(all_examples, sort_key_fn=sort_key_fn, batch_size_fn=batch_size_fn, groups=groups)
 
         return Split(train=splits.get('train'),
                      eval=splits.get('eval'),
diff --git a/genienlp/util.py b/genienlp/util.py
index f53aac99..ee840c42 100644
--- a/genienlp/util.py
+++ b/genienlp/util.py
@@ -246,12 +246,11 @@ def make_data_loader(dataset, numericalizer, batch_size, device=None, paired=Fal
                              repeat=train,
                              use_data_batch_fn=train,
                              use_data_sort_key=train)
+
+    collate_function = lambda minibatch: Batch.from_examples(minibatch, numericalizer, device=device,
+                                                             paired=paired and train, max_pairs=max_pairs, groups=iterator.groups)
 
-    return torch.utils.data.DataLoader(iterator,
-                                       batch_size=None,
-                                       collate_fn=lambda minibatch: Batch.from_examples(minibatch, numericalizer,
-                                                                                        device=device, paired=paired and train,
-                                                                                        max_pairs=max_pairs, groups=iterator.groups))
+    return torch.utils.data.DataLoader(iterator, batch_size=None, collate_fn=collate_function)
 
 
 def pad(x, new_channel, dim, val=None):
diff --git a/tests/test.sh b/tests/test.sh
index 85fe7bd2..ec0ad5c1 100755
--- a/tests/test.sh
+++ b/tests/test.sh
@@ -39,7 +39,7 @@ for hparams in \
 do
 
     # train
-    pipenv run python3 -m genienlp train --train_tasks almond --train_iterations 6 --preserve_case --save_every 2 --log_every 2 --val_every 2 --save $workdir/model_$i --data $SRCDIR/dataset/ $hparams --exist_ok --skip_cache --root "" --embeddings $embedding_dir --no_commit
+    pipenv run python3 -m genienlp train --train_tasks almond --train_iterations 6 --preserve_case --save_every 2 --log_every 2 --val_every 2 --save $workdir/model_$i --data $SRCDIR/dataset/ $hparams --exist_ok --skip_cache --embeddings $embedding_dir --no_commit
 
     # greedy decode
     pipenv run python3 -m genienlp predict --tasks almond --evaluate test --path $workdir/model_$i --overwrite --eval_dir $workdir/model_$i/eval_results/ --data $SRCDIR/dataset/ --embeddings $embedding_dir
@@ -63,7 +63,7 @@ for hparams in \
 do
 
    # train
-    pipenv run python3 -m genienlp train --train_tasks almond_multilingual --train_languages fa+en --eval_languages fa+en --train_iterations 6 --preserve_case --save_every 2 --log_every 2 --val_every 2 --save $workdir/model_$i --data $SRCDIR/dataset/ $hparams --exist_ok --skip_cache --root "" --embeddings $embedding_dir --no_commit
+    pipenv run python3 -m genienlp train --train_tasks almond_multilingual --train_languages fa+en --eval_languages fa+en --train_iterations 6 --preserve_case --save_every 2 --log_every 2 --val_every 2 --save $workdir/model_$i --data $SRCDIR/dataset/ $hparams --exist_ok --skip_cache --embeddings $embedding_dir --no_commit
 
     # greedy decode
     # combined evaluation
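
Note: the functional core of the AlmondMultiLingual change above is that `groups` is now computed once by the caller (len(all_datasets) only when sentence batching is enabled, otherwise None) and threaded through combine_datasets into CQA, instead of being hard-coded to len(datasets). The following is a minimal, self-contained sketch of that plumbing; the CQA, Split, and FakeDataset classes here are simplified stand-ins for illustration, not genienlp's real implementations.

from collections import defaultdict
from typing import NamedTuple, Optional


class CQA(NamedTuple):
    """Stand-in for a dataset split: examples plus an optional batching hint."""
    examples: list
    groups: Optional[int]


class Split(NamedTuple):
    train: Optional[CQA]
    eval: Optional[CQA]
    test: Optional[CQA]


def combine_datasets(datasets, used_fields, groups):
    """Merge per-language datasets field by field, forwarding `groups` as given."""
    splits = defaultdict(lambda: None)
    for field in used_fields:
        all_examples = []
        for dataset in datasets:
            all_examples.extend(getattr(dataset, field).examples)
        splits[field] = CQA(all_examples, groups=groups)
    return Split(train=splits['train'], eval=splits['eval'], test=splits['test'])


if __name__ == '__main__':
    class FakeDataset(NamedTuple):
        train: CQA

    datasets = [FakeDataset(train=CQA(['en ex1', 'en ex2'], None)),
                FakeDataset(train=CQA(['fa ex1'], None))]

    # With sentence batching on, there is one group per language dataset;
    # with it off, groups stays None and grouped batching is disabled.
    sentence_batching = True
    groups = len(datasets) if sentence_batching else None
    combined = combine_datasets(datasets, ['train'], groups)
    print(combined.train.groups)         # 2
    print(len(combined.train.examples))  # 3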