From 5f645e5c9c45fd57be76d7ec3d432abf826e0723 Mon Sep 17 00:00:00 2001 From: mehrad Date: Thu, 11 Mar 2021 11:33:43 -0800 Subject: [PATCH] batch_postprocess_prediction_ids should return target_ids --- genienlp/tasks/almond_task.py | 7 +++---- genienlp/tasks/base_task.py | 4 ++-- genienlp/validate.py | 4 ++-- 3 files changed, 7 insertions(+), 8 deletions(-) diff --git a/genienlp/tasks/almond_task.py b/genienlp/tasks/almond_task.py index 87b8fca8..43bbc7a6 100644 --- a/genienlp/tasks/almond_task.py +++ b/genienlp/tasks/almond_task.py @@ -265,8 +265,8 @@ class BaseAlmondTask(BaseTask): token_freqs = [[default_val] * default_size] * len(tokens) return token_freqs - def batch_postprocess_prediction_ids(self, batch_example_ids, batch_src_ids, **kwargs): - return batch_src_ids + def batch_postprocess_prediction_ids(self, batch_example_ids, batch_src_ids, batch_tgt_ids, **kwargs): + return batch_tgt_ids def postprocess_prediction(self, example_id, prediction): @@ -566,9 +566,8 @@ class Translate(NaturalSeq2Seq): def postprocess_prediction(self, example_id, prediction): return super().postprocess_prediction(example_id, prediction) - def batch_postprocess_prediction_ids(self, batch_example_ids, batch_src_ids, **kwargs): + def batch_postprocess_prediction_ids(self, batch_example_ids, batch_src_ids, batch_tgt_ids, **kwargs): - batch_tgt_ids = kwargs.pop('batch_tgt_ids') numericalizer = kwargs.pop('numericalizer') cross_attentions = kwargs.pop('cross_attentions') diff --git a/genienlp/tasks/base_task.py b/genienlp/tasks/base_task.py index ae50ad6e..cee4b4b1 100644 --- a/genienlp/tasks/base_task.py +++ b/genienlp/tasks/base_task.py @@ -69,8 +69,8 @@ class BaseTask: """ return generic_dataset.JSON.splits(root=root, name=self.name, **kwargs) - def batch_postprocess_prediction_ids(self, batch_example_ids, batch_src_ids, **kwargs): - return batch_src_ids + def batch_postprocess_prediction_ids(self, batch_example_ids, batch_src_ids, batch_tgt_ids, **kwargs): + return batch_tgt_ids def postprocess_prediction(self, example_id, prediction): return prediction diff --git a/genienlp/validate.py b/genienlp/validate.py index 425eaa3d..b9ff8abc 100644 --- a/genienlp/validate.py +++ b/genienlp/validate.py @@ -111,8 +111,8 @@ def generate_with_model(model, data_iterator, numericalizer, task, args, cross_attentions = cross_attentions[-1, ...] # postprocess prediction ids - kwargs = {'batch_tgt_ids': partial_batch_prediction_ids, 'numericalizer': numericalizer, 'cross_attentions': cross_attentions} - partial_batch_prediction_ids = task.batch_postprocess_prediction_ids(batch_example_ids, batch.context.value.data, **kwargs) + kwargs = {'numericalizer': numericalizer, 'cross_attentions': cross_attentions} + partial_batch_prediction_ids = task.batch_postprocess_prediction_ids(batch_example_ids, batch.context.value.data, partial_batch_prediction_ids, **kwargs) if output_confidence_features or output_confidence_scores: partial_batch_confidence_features = model.confidence_features(batch=batch, predictions=partial_batch_prediction_ids, mc_dropout_num=args.mc_dropout_num)